mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-19 06:19:54 +00:00
Compare commits
19 Commits
mdm_integr
...
fix/browse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6759ce93c | ||
|
|
8d9580e491 | ||
|
|
5bd7c6c7ea | ||
|
|
8ae2cd0a08 | ||
|
|
e4397d4d46 | ||
|
|
6fbc90b4d3 | ||
|
|
5095e17cc5 | ||
|
|
6df0175607 | ||
|
|
3c23700e56 | ||
|
|
38ad2b67e8 | ||
|
|
01aa49433e | ||
|
|
08a2b63675 | ||
|
|
b3f9e6588a | ||
|
|
967e2d6864 | ||
|
|
e7c1d364c3 | ||
|
|
a44198fd77 | ||
|
|
b57f714350 | ||
|
|
f893abc41d | ||
|
|
60067619a1 |
55
.github/workflows/release.yml
vendored
55
.github/workflows/release.yml
vendored
@@ -9,10 +9,13 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.5"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
SIGN_PIPE_VER: "v0.1.6"
|
||||
GORELEASER_VER: "v2.16.0"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
flags: ""
|
||||
SKIP_PUBLISH: "true"
|
||||
SKIP_DOCKER_PUSH: "false"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -130,8 +133,6 @@ jobs:
|
||||
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
|
||||
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
@@ -143,8 +144,27 @@ jobs:
|
||||
id: semver_parser
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Set snapshot flag
|
||||
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Set build vars
|
||||
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
|
||||
else
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
fi
|
||||
|
||||
if [[ "x-${{ github.repository }}" != "x-netbirdio/netbird" ]]; then
|
||||
echo "SKIP_DOCKER_PUSH=true" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
@@ -161,6 +181,8 @@ jobs:
|
||||
${{ runner.os }}-go-releaser-
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
- name: run openapi generator
|
||||
run: bash shared/management/http/api/generate.sh
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
@@ -210,6 +232,8 @@ jobs:
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
|
||||
SKIP_DOCKER_PUSH: ${{ env.SKIP_DOCKER_PUSH }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
@@ -332,8 +356,22 @@ jobs:
|
||||
id: semver_parser
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Set snapshot flag
|
||||
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Set build vars
|
||||
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
|
||||
else
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
@@ -393,6 +431,7 @@ jobs:
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
|
||||
862
.goreleaser.yaml
862
.goreleaser.yaml
@@ -1,5 +1,7 @@
|
||||
version: 2
|
||||
|
||||
env:
|
||||
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
|
||||
- SKIP_DOCKER_PUSH={{ if index .Env "SKIP_DOCKER_PUSH" }}{{ .Env.SKIP_DOCKER_PUSH }}{{ else }}false{{ end }}
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird-wasm
|
||||
@@ -74,6 +76,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -88,6 +92,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -102,6 +108,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -122,6 +130,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -136,6 +146,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -150,6 +162,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -170,6 +184,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -222,670 +238,192 @@ nfpms:
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
dockers:
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: signal/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: signal/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: signal/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile.debug
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile.debug
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile.debug
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:latest
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/relay:latest
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/signal:latest
|
||||
image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/management:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/management:latest
|
||||
image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/management:debug-latest
|
||||
image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- name_template: netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/upload:latest
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:debug-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
dockers_v2:
|
||||
- id: netbird
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird
|
||||
images:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm/6
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: netbird-rootless
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird
|
||||
images:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "v{{ .Version }}-rootless"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm/6
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: relay
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-relay
|
||||
images:
|
||||
- netbirdio/relay
|
||||
- ghcr.io/netbirdio/relay
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: relay/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: signal
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-signal
|
||||
images:
|
||||
- netbirdio/signal
|
||||
- ghcr.io/netbirdio/signal
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: signal/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: management
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
images:
|
||||
- netbirdio/management
|
||||
- ghcr.io/netbirdio/management
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: management/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: upload
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-upload
|
||||
images:
|
||||
- netbirdio/upload
|
||||
- ghcr.io/netbirdio/upload
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: upload-server/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: netbird-server
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-server
|
||||
images:
|
||||
- netbirdio/netbird-server
|
||||
- ghcr.io/netbirdio/netbird-server
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: combined/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: netbird-proxy
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-proxy
|
||||
images:
|
||||
- netbirdio/reverse-proxy
|
||||
- ghcr.io/netbirdio/reverse-proxy
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: proxy/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
skip_upload: "{{ .Env.SKIP_PUBLISH }}"
|
||||
repository:
|
||||
owner: netbirdio
|
||||
name: homebrew-tap
|
||||
@@ -902,6 +440,7 @@ brews:
|
||||
|
||||
uploads:
|
||||
- name: debian
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_deb
|
||||
mode: archive
|
||||
@@ -910,6 +449,7 @@ uploads:
|
||||
method: PUT
|
||||
|
||||
- name: yum
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_rpm
|
||||
mode: archive
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
version: 2
|
||||
|
||||
env:
|
||||
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
@@ -101,6 +102,7 @@ nfpms:
|
||||
|
||||
uploads:
|
||||
- name: debian
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_ui_deb
|
||||
mode: archive
|
||||
@@ -109,6 +111,7 @@ uploads:
|
||||
method: PUT
|
||||
|
||||
- name: yum
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_ui_rpm
|
||||
mode: archive
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.23.3
|
||||
FROM alpine:3.24
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
@@ -21,7 +21,7 @@ ENV \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
ARG TARGETPLATFORM
|
||||
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
|
||||
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
|
||||
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.0
|
||||
FROM alpine:3.24
|
||||
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
@@ -27,7 +27,7 @@ ENV \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
ARG TARGETPLATFORM
|
||||
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
|
||||
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
|
||||
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@@ -76,13 +75,6 @@ type Client struct {
|
||||
connectClient *internal.ConnectClient
|
||||
config *profilemanager.Config
|
||||
cacheDir string
|
||||
|
||||
// mdmLoader holds the per-Client MDM policy source. Set by
|
||||
// SetMDMPolicyFetcher (called from the Kotlin side). Each Run
|
||||
// passes this loader to the resolved Config so applyMDMPolicy
|
||||
// picks up the active overlay. Nil means "MDM enforcement off
|
||||
// for this Client".
|
||||
mdmLoader *mdm.Loader
|
||||
}
|
||||
|
||||
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
||||
@@ -137,7 +129,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.applyMDMOverlay(cfg)
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
@@ -182,7 +173,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.applyMDMOverlay(cfg)
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
@@ -240,7 +230,6 @@ func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (strin
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
c.applyMDMOverlay(cfg)
|
||||
cacheDir = platformFiles.CacheDir()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// PolicyFetcher is the mobile-side bridge for the MDM managed-config
|
||||
// snapshot. The native layer (Kotlin) implements this and registers
|
||||
// the instance per Client via Client.SetMDMPolicyFetcher. Every
|
||||
// invocation of fetchJSON must read the current RestrictionsManager
|
||||
// state and return the result as a JSON-encoded map[string]any string.
|
||||
//
|
||||
// JSON is used because gomobile does not support map[string]any
|
||||
// crossing the JNI boundary — the adapter on the Go side parses the
|
||||
// string back into the map[string]any expected by mdm.Loader.
|
||||
//
|
||||
// Return value contract:
|
||||
// - "" (empty) : interpreted as "no MDM source / no managed keys"
|
||||
// - "{}" : managed config explicitly empty
|
||||
// - "{...}" : JSON object with key/value pairs
|
||||
// - malformed JSON : logged and treated as empty
|
||||
type PolicyFetcher interface {
|
||||
FetchJSON() string
|
||||
}
|
||||
|
||||
// jsonFetcherAdapter wraps a gomobile-exposed PolicyFetcher into the
|
||||
// internal mdm.PolicyFetcher interface, taking care of JSON decoding
|
||||
// on every Fetch.
|
||||
type jsonFetcherAdapter struct {
|
||||
inner PolicyFetcher
|
||||
}
|
||||
|
||||
func (a *jsonFetcherAdapter) Fetch() map[string]any {
|
||||
raw := a.inner.FetchJSON()
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
log.Warnf("MDM mobile fetcher: invalid JSON payload from native: %v", err)
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetMDMPolicyFetcher registers the native-provided MDM policy fetcher
|
||||
// on this Client. Call once from the gomobile-init code (Kotlin
|
||||
// Application.onCreate or Service onCreate) before invoking Run /
|
||||
// RunWithoutLogin. Passing nil disables MDM enforcement on this
|
||||
// Client.
|
||||
//
|
||||
// The fetcher is held as a *mdm.Loader instance on the Client (no
|
||||
// package-level state) — multiple Clients in the same process get
|
||||
// independent Loaders, and tests can inject fakes per Client.
|
||||
func (c *Client) SetMDMPolicyFetcher(p PolicyFetcher) {
|
||||
if p == nil {
|
||||
c.mdmLoader = mdm.NewLoader(nil)
|
||||
return
|
||||
}
|
||||
c.mdmLoader = mdm.NewLoader(&jsonFetcherAdapter{inner: p})
|
||||
}
|
||||
|
||||
// applyMDMOverlay applies the Client-held MDM Loader's current policy
|
||||
// on top of the just-read Config. Called immediately after every
|
||||
// UpdateOrCreateConfig — profilemanager's apply() initialises the
|
||||
// policy to empty and leaves overlay responsibility to the lifecycle
|
||||
// owner. No-op when no fetcher was registered.
|
||||
func (c *Client) applyMDMOverlay(cfg *profilemanager.Config) {
|
||||
if cfg == nil || c.mdmLoader == nil {
|
||||
return
|
||||
}
|
||||
cfg.ApplyMDMPolicy(c.mdmLoader.Load())
|
||||
}
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -249,11 +248,6 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
// CLI standalone login: profilemanager no longer auto-applies MDM,
|
||||
// so layer in the OS-native policy here. Desktop builds construct
|
||||
// a Loader with no fetcher — the build-tagged loadPlatform reads
|
||||
// the registry/plist directly.
|
||||
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -188,10 +187,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
// CLI foreground path runs without the daemon Server: layer in the
|
||||
// active MDM policy explicitly so a forced ManagementURL / PSK /
|
||||
// other managed key actually takes effect on this run.
|
||||
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -216,10 +215,6 @@ func New(opts Options) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
}
|
||||
// Embedded path runs without the daemon Server: apply the active
|
||||
// MDM policy explicitly so a forced ManagementURL / PSK / other
|
||||
// managed key takes effect on this embedded engine instance.
|
||||
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
|
||||
|
||||
if opts.PrivateKey != "" {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
|
||||
@@ -41,7 +41,6 @@ type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn udpmux.FilterFn
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
|
||||
@@ -61,12 +60,11 @@ type ICEBind struct {
|
||||
ipv6Conn *net.UDPConn
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
@@ -265,7 +263,6 @@ func (s *ICEBind) createOrUpdateMux() {
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: muxConn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
|
||||
@@ -289,7 +289,7 @@ func setupICEBind(t *testing.T) *ICEBind {
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||
}
|
||||
return NewICEBind(transportNet, nil, address, 1280)
|
||||
return NewICEBind(transportNet, address, 1280)
|
||||
}
|
||||
|
||||
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
@@ -41,10 +44,13 @@ type PacketCapture interface {
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
// panicHandler is invoked after a panic in the underlying device is
|
||||
// recovered in Read or Write.
|
||||
panicHandler atomic.Pointer[func()]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// newDeviceFilter constructor function
|
||||
@@ -70,7 +76,7 @@ func (d *FilteredDevice) Close() error {
|
||||
|
||||
// Read wraps read method with filtering feature
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
if n, err = d.deviceRead(bufs, sizes, offset); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -112,7 +118,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if filter == nil {
|
||||
return d.Device.Write(bufs, offset)
|
||||
return d.deviceWrite(bufs, offset)
|
||||
}
|
||||
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
@@ -125,9 +131,44 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
n, err := d.Device.Write(filteredBufs, offset)
|
||||
n += dropped
|
||||
return n, err
|
||||
n, err := d.deviceWrite(filteredBufs, offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n + dropped, nil
|
||||
}
|
||||
|
||||
// deviceRead calls the underlying device Read, recovering from panics in the
|
||||
// wintun read path and converting them into errors.
|
||||
func (d *FilteredDevice) deviceRead(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
defer d.recoverFromPanic("read", &n, &err)
|
||||
return d.Device.Read(bufs, sizes, offset)
|
||||
}
|
||||
|
||||
// deviceWrite calls the underlying device Write, recovering from panics in the
|
||||
// wintun write path and converting them into errors.
|
||||
func (d *FilteredDevice) deviceWrite(bufs [][]byte, offset int) (n int, err error) {
|
||||
defer d.recoverFromPanic("write", &n, &err)
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// recoverFromPanic converts a panic in the underlying device into a regular
|
||||
// error and invokes the registered panic handler. The wintun read path is
|
||||
// known to panic on zero-length packets that third-party filter drivers can
|
||||
// place in the ring.
|
||||
func (d *FilteredDevice) recoverFromPanic(op string, n *int, err *error) {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("recovered panic in tun device %s: %v\n%s", op, r, debug.Stack())
|
||||
*n = 0
|
||||
*err = fmt.Errorf("tun device %s panic: %v", op, r)
|
||||
|
||||
if handler := d.panicHandler.Load(); handler != nil {
|
||||
(*handler)()
|
||||
}
|
||||
}
|
||||
|
||||
// SetFilter sets packet filter to device
|
||||
@@ -137,6 +178,17 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
|
||||
d.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetPanicHandler registers a handler invoked after a recovered panic in Read
|
||||
// or Write. The device is unusable after such a panic; the handler should
|
||||
// trigger recreation of the interface. Pass nil to remove.
|
||||
func (d *FilteredDevice) SetPanicHandler(handler func()) {
|
||||
if handler == nil {
|
||||
d.panicHandler.Store(nil)
|
||||
return
|
||||
}
|
||||
d.panicHandler.Store(&handler)
|
||||
}
|
||||
|
||||
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
|
||||
// Uses atomic store so the hot path (Read/Write) is a single pointer load
|
||||
// with no locking overhead when capture is off.
|
||||
|
||||
@@ -221,3 +221,60 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceWrapperReadPanic(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tun := mocks.NewMockDevice(ctrl)
|
||||
tun.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
// Reproduce the wintun zero-length packet panic (index out of range).
|
||||
packet := make([]byte, 0)
|
||||
return int(packet[0]), nil
|
||||
})
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
|
||||
handlerCalled := false
|
||||
wrapped.SetPanicHandler(func() { handlerCalled = true })
|
||||
|
||||
n, err := wrapped.Read([][]byte{{}}, []int{0}, 0)
|
||||
if err == nil {
|
||||
t.Errorf("expected error from recovered panic, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("expected n=0, got %d", n)
|
||||
}
|
||||
if !handlerCalled {
|
||||
t.Errorf("expected panic handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceWrapperWritePanic(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tun := mocks.NewMockDevice(ctrl)
|
||||
tun.EXPECT().Write(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(bufs [][]byte, offset int) (int, error) {
|
||||
packet := make([]byte, 0)
|
||||
return int(packet[0]), nil
|
||||
})
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
|
||||
handlerCalled := false
|
||||
wrapped.SetPanicHandler(func() { handlerCalled = true })
|
||||
|
||||
n, err := wrapped.Write([][]byte{{0x45, 0x00}}, 0)
|
||||
if err == nil {
|
||||
t.Errorf("expected error from recovered panic, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("expected n=0, got %d", n)
|
||||
}
|
||||
if !handlerCalled {
|
||||
t.Errorf("expected panic handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,8 +32,6 @@ type TunKernelDevice struct {
|
||||
link *wgLink
|
||||
udpMuxConn net.PacketConn
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
|
||||
filterFn udpmux.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
||||
@@ -104,7 +102,6 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
MTU: t.mtu,
|
||||
}
|
||||
|
||||
@@ -63,7 +63,6 @@ type WGIFaceOpts struct {
|
||||
MTU uint16
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn udpmux.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace := &WGIface{
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
@@ -30,7 +30,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
|
||||
userspaceBind: true,
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -22,10 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
|
||||
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||
type UniversalUDPMuxDefault struct {
|
||||
@@ -43,7 +37,6 @@ type UniversalUDPMuxParams struct {
|
||||
UDPConn net.PacketConn
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
FilterFn FilterFn
|
||||
WGAddress wgaddr.Address
|
||||
MTU uint16
|
||||
}
|
||||
@@ -68,7 +61,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
filterFn: params.FilterFn,
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
@@ -115,15 +107,12 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides WriteTo to drop packets destined for the overlay subnet.
|
||||
type UDPConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
filterFn FilterFn
|
||||
// TODO: reset cache on route changes
|
||||
addrCache sync.Map
|
||||
address wgaddr.Address
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
// GetPacketConn returns the underlying PacketConn
|
||||
@@ -132,65 +121,16 @@ func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if u.filterFn == nil {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
if isRouted, found := u.addrCache.Load(addr.String()); found {
|
||||
return u.handleCachedAddress(isRouted.(bool), b, addr)
|
||||
}
|
||||
|
||||
return u.handleUncachedAddress(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
host, err := getHostFromAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a) {
|
||||
dst := udpAddr.AddrPort().Addr().Unmap()
|
||||
if (u.address.Network.IsValid() && u.address.Network.Contains(dst)) || (u.address.IPv6Net.IsValid() && u.address.IPv6Net.Contains(dst)) {
|
||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return 0, fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
u.addrCache.Store(addr.String(), isRouted)
|
||||
if isRouted {
|
||||
// Extra log, as the error only shows up with ICE logging enabled
|
||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHostFromAddr(addr net.Addr) (string, error) {
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
return host, err
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// GetSharedConn returns the shared udp conn
|
||||
@@ -225,6 +165,13 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
|
||||
return nil
|
||||
}
|
||||
|
||||
src := udpAddr.AddrPort().Addr().Unmap()
|
||||
wg := m.params.WGAddress
|
||||
if (wg.Network.IsValid() && wg.Network.Contains(src)) || (wg.IPv6Net.IsValid() && wg.IPv6Net.Contains(src)) {
|
||||
log.Debugf("dropping STUN message from overlay source %s", udpAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.isXORMappedResponse(msg, udpAddr.String()) {
|
||||
err := m.handleXORMappedResponse(udpAddr, msg)
|
||||
if err != nil {
|
||||
|
||||
@@ -66,7 +66,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
@@ -22,7 +22,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
@@ -118,6 +118,8 @@ func (c *ConnectClient) RunOniOS(
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
stateFilePath string,
|
||||
cacheDir string,
|
||||
logFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
debug.SetGCPercent(5)
|
||||
@@ -127,8 +129,9 @@ func (c *ConnectClient) RunOniOS(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
return c.run(mobileDependency, nil, logFilePath)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
|
||||
@@ -250,6 +250,7 @@ type BundleGenerator struct {
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logPath string
|
||||
tempDir string
|
||||
statePath string
|
||||
cpuProfile []byte
|
||||
capturePath string
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
@@ -276,6 +277,7 @@ type GeneratorDependencies struct {
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogPath string
|
||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||
StatePath string // Path to the state file. If empty, the ServiceManager default path is used.
|
||||
CPUProfile []byte
|
||||
CapturePath string
|
||||
RefreshStatus func()
|
||||
@@ -299,6 +301,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
syncResponse: deps.SyncResponse,
|
||||
logPath: deps.LogPath,
|
||||
tempDir: deps.TempDir,
|
||||
statePath: deps.StatePath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
capturePath: deps.CapturePath,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
@@ -850,8 +853,11 @@ func (g *BundleGenerator) maskSecrets() {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStateFile() error {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path := sm.GetStatePath()
|
||||
path := g.statePath
|
||||
if path == "" {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path = sm.GetStatePath()
|
||||
}
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
36
client/internal/debug/debug_ios.go
Normal file
36
client/internal/debug/debug_ios.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// swiftLogFile is the Swift app log written by the iOS app into the same log
|
||||
// directory as the Go client log, so it can be collected into the bundle.
|
||||
const swiftLogFile = "swift-log.log"
|
||||
|
||||
// addPlatformLog collects logs for the iOS debug bundle. iOS has no logcat or
|
||||
// systemd journal, so we rely on file-based logs. addLogfile handles the Go
|
||||
// client log (logPath) with rotation, the stderr/stdout companions and
|
||||
// anonymization. The iOS app writes its own Swift log into the same directory,
|
||||
// so we add it alongside the Go log.
|
||||
func (g *BundleGenerator) addPlatformLog() error {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if g.logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
swiftLogPath := filepath.Join(filepath.Dir(g.logPath), swiftLogFile)
|
||||
if err := g.addSingleLogfile(swiftLogPath, swiftLogFile); err != nil {
|
||||
// The Swift log is best-effort: the app may not have written it yet.
|
||||
log.Warnf("failed to add %s to debug bundle: %v", swiftLogFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build !android && !ios
|
||||
|
||||
package debug
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/syncstore"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
@@ -240,7 +239,7 @@ type Engine struct {
|
||||
syncStore syncstore.Store
|
||||
syncStoreDir string
|
||||
|
||||
flowManager nftypes.FlowManager
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updater.Manager
|
||||
@@ -531,6 +530,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
if filteredDevice := e.wgInterface.GetDevice(); filteredDevice != nil {
|
||||
filteredDevice.SetPanicHandler(e.triggerClientRestart)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
@@ -1711,6 +1714,13 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
// Self-addressed heartbeat: the signal client's receive watchdog
|
||||
// round-trips this through the server to confirm the receive stream
|
||||
// is delivering. Liveness is already recorded before this handler.
|
||||
if msg.GetBody().GetType() == sProto.Body_HEARTBEAT {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
@@ -1909,7 +1919,6 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
WGPrivKey: e.config.WgPrivateKey.String(),
|
||||
MTU: e.config.MTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
DisableDNS: e.config.DisableDNS,
|
||||
}
|
||||
|
||||
@@ -2157,21 +2166,6 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
|
||||
return true, prefix, nil
|
||||
}
|
||||
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSServer() {
|
||||
if e.dnsServer == nil {
|
||||
return
|
||||
|
||||
@@ -1024,14 +1024,17 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
// extend the list of stun, turn servers with relay address
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
states := d.relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
@@ -1041,10 +1044,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
return relayStates
|
||||
}
|
||||
|
||||
relayState := relay.ProbeResult{
|
||||
URI: instanceAddr,
|
||||
for _, rs := range states {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: rs.URL,
|
||||
Err: rs.Err,
|
||||
Transport: rs.Transport,
|
||||
})
|
||||
}
|
||||
return append(relayStates, relayState)
|
||||
return relayStates
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
@@ -1405,6 +1412,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
Transport: relayState.Transport,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -165,10 +164,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
if candidateViaRoutes(candidate, haRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
@@ -589,34 +584,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
var routePrefixes []netip.Prefix
|
||||
for _, routes := range clientRoutes {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
routePrefixes = append(routePrefixes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range routePrefixes {
|
||||
// default route is handled by route exclusion / ip rules
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
@@ -58,6 +58,10 @@ var DefaultInterfaceBlacklist = []string{
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// loadMDMPolicy is the package-level indirection used by apply() to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
@@ -176,27 +180,14 @@ type Config struct {
|
||||
|
||||
MTU uint16
|
||||
|
||||
// policy is the MDM policy that produced the currently-set values
|
||||
// for any MDM-enforced fields. Set by ApplyMDMPolicy on every
|
||||
// invocation. Never persisted to disk. Callers query enforcement
|
||||
// state via Policy() and the mdm.Policy API (HasKey, ManagedKeys,
|
||||
// IsEmpty).
|
||||
// policy is the MDM policy that produced the currently-set values for
|
||||
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
|
||||
// and reset on every apply() invocation. Never persisted to disk.
|
||||
// Callers query enforcement state via Policy() and the mdm.Policy API
|
||||
// (HasKey, ManagedKeys, IsEmpty).
|
||||
policy *mdm.Policy `json:"-"`
|
||||
}
|
||||
|
||||
// ApplyMDMPolicy overlays the supplied MDM Policy on top of the
|
||||
// currently resolved Config values. Idempotent — pass an empty Policy
|
||||
// to clear any prior overlay. The lifecycle owner (Server.getConfig
|
||||
// on desktop, the Client.Run path on mobile) calls this with
|
||||
// loader.Load() once the per-process Loader is known; the Config
|
||||
// itself holds no reference to the Loader.
|
||||
func (config *Config) ApplyMDMPolicy(policy *mdm.Policy) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
config.applyMDMPolicy(policy)
|
||||
}
|
||||
|
||||
// Policy returns the MDM policy applied to this Config. Returns a non-nil
|
||||
// empty Policy when MDM enforcement is inactive; callers can always invoke
|
||||
// HasKey / ManagedKeys / IsEmpty without a nil check.
|
||||
@@ -643,11 +634,9 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
// Initialise the MDM overlay to "no enforcement" so Config.Policy()
|
||||
// never returns a stale or nil policy on a freshly applied Config.
|
||||
// Lifecycle owners that want to enforce a real MDM policy invoke
|
||||
// Config.ApplyMDMPolicy(loader.Load()) after this returns.
|
||||
config.applyMDMPolicy(mdm.NewPolicy(nil))
|
||||
// MDM is the last override layer: any key present in the policy
|
||||
// supersedes defaults, on-disk config, env vars and CLI input.
|
||||
config.applyMDMPolicy(loadMDMPolicy())
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
@@ -10,58 +10,24 @@ import (
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// fakeFetcher implements mdm.PolicyFetcher returning a pre-set policy
|
||||
// map. Test helper used to construct a Loader without touching the OS
|
||||
// or any package-level state.
|
||||
type fakeFetcher struct{ values map[string]any }
|
||||
|
||||
func (f *fakeFetcher) Fetch() map[string]any { return f.values }
|
||||
|
||||
// loaderFor builds an mdm.Loader whose loadPlatform returns the
|
||||
// supplied Policy's underlying values.
|
||||
func loaderFor(policy *mdm.Policy) *mdm.Loader {
|
||||
if policy == nil || policy.IsEmpty() {
|
||||
return mdm.NewLoader(&fakeFetcher{values: nil})
|
||||
}
|
||||
values := make(map[string]any)
|
||||
for _, k := range policy.ManagedKeys() {
|
||||
if v, ok := policy.GetString(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetBool(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetInt(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetStringSlice(k); ok {
|
||||
values[k] = v
|
||||
}
|
||||
}
|
||||
return mdm.NewLoader(&fakeFetcher{values: values})
|
||||
}
|
||||
|
||||
// configWithMDM is the test convenience that builds a Config via
|
||||
// UpdateOrCreateConfig and overlays the supplied MDM policy on top —
|
||||
// mirrors the production pattern (Server.getConfig / Client.applyMDMOverlay)
|
||||
// where the Loader lives outside Config and the apply step is driven
|
||||
// by the lifecycle owner.
|
||||
func configWithMDM(t *testing.T, input ConfigInput, policy *mdm.Policy) *Config {
|
||||
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
|
||||
// apply() observes the supplied Policy. The original loader is restored at
|
||||
// test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
cfg, err := UpdateOrCreateConfig(input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
cfg.ApplyMDMPolicy(loaderFor(policy).Load())
|
||||
return cfg
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
}, mdm.NewPolicy(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
@@ -73,15 +39,18 @@ func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
|
||||
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
|
||||
const mdmURL = "https://corp.mdm.example.com:443"
|
||||
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyBlockInbound: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.DisableClientRoutes)
|
||||
assert.True(t, cfg.BlockInbound)
|
||||
@@ -96,25 +65,33 @@ func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
const mdmURL = "https://mdm.example.com:443"
|
||||
const cliURL = "https://cli.example.com:443"
|
||||
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// MDM wins over CLI-supplied management URL.
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "not-a-url",
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Invalid MDM URL is logged and skipped: default URL stays in place
|
||||
// to keep the client functional.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
@@ -129,20 +106,24 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Seed without MDM.
|
||||
configWithMDM(t, ConfigInput{
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
_, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
DisableClientRoutes: boolPtr(false),
|
||||
RosenpassEnabled: boolPtr(false),
|
||||
}, mdm.NewPolicy(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now enable MDM enforcement for these keys.
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
|
||||
assert.True(t, cfg.RosenpassEnabled)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
@@ -152,12 +133,16 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyPreSharedKey: maskSentinel,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Mask sentinel must not be persisted as the actual PSK.
|
||||
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
|
||||
// Key still marked managed so user writes are still rejected.
|
||||
|
||||
@@ -32,6 +32,9 @@ type ProbeResult struct {
|
||||
URI string
|
||||
Err error
|
||||
Addr string
|
||||
// Transport is the negotiated relay transport, empty
|
||||
// for stun/turn probes or when not connected.
|
||||
Transport string
|
||||
}
|
||||
|
||||
type StunTurnProbe struct {
|
||||
|
||||
@@ -22,14 +22,14 @@ type removePeerCall struct {
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
}
|
||||
|
||||
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||
@@ -51,7 +51,7 @@ func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||
return m.removeErr
|
||||
}
|
||||
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||
|
||||
type setPSKCall struct {
|
||||
|
||||
@@ -41,4 +41,3 @@ func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||
_, err = DeterministicSeedKey(long, short)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -332,6 +333,8 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
}
|
||||
|
||||
m.notifier.Close()
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.clientRoutes = nil
|
||||
@@ -700,6 +703,8 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
m.mirrorV6ExitPairSelections(clientRoutes)
|
||||
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
@@ -716,6 +721,24 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
|
||||
m.logExitNodeUpdate(exitNodeInfo)
|
||||
}
|
||||
|
||||
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
||||
// consistent with its v4 base. The v4/v6 exit pair is a single toggle, so the v6
|
||||
// entry always follows the base: deselecting the v4 exit node also drops its ::/0
|
||||
// pair, and any stale (orphaned) explicit selection on the v6 entry is reset. This
|
||||
// runs before selection is read so both collectExitNodeInfo and FilterSelectedExitNodes
|
||||
// see consistent state, including pairs loaded from persisted selector state.
|
||||
func (m *DefaultManager) mirrorV6ExitPairSelections(clientRoutes route.HAMap) {
|
||||
routesByNetID := make(map[route.NetID][]*route.Route, len(clientRoutes))
|
||||
for haID, routes := range clientRoutes {
|
||||
routesByNetID[haID.NetID()] = routes
|
||||
}
|
||||
|
||||
for v6ID := range route.V6ExitMergeSet(routesByNetID) {
|
||||
baseID := route.NetID(strings.TrimSuffix(string(v6ID), route.V6ExitSuffix))
|
||||
m.routeSelector.SyncPairedSelection(baseID, v6ID)
|
||||
}
|
||||
}
|
||||
|
||||
type exitNodeInfo struct {
|
||||
allIDs []route.NetID
|
||||
selectedByManagement []route.NetID
|
||||
|
||||
47
client/internal/routemanager/manager_v6exit_test.go
Normal file
47
client/internal/routemanager/manager_v6exit_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair reproduces the bug seen
|
||||
// in netbird-engine.log: persisted selector state has the v4 exit node deselected
|
||||
// but its synthesized "-v6" pair explicitly selected (orphaned), so the ::/0 route
|
||||
// leaked onto the tunnel. The management update must mirror the v4 deselect onto the
|
||||
// v6 pair so FilterSelectedExitNodes drops it.
|
||||
func TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair(t *testing.T) {
|
||||
const (
|
||||
v4ID = route.NetID("Exit Node (raspberrypi)")
|
||||
v6ID = route.NetID("Exit Node (raspberrypi)-v6")
|
||||
)
|
||||
all := []route.NetID{v4ID, v6ID}
|
||||
|
||||
rs := routeselector.NewRouteSelector()
|
||||
// Orphan the v6 selection: select the pair, then deselect only the v4 base.
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{v4ID, v6ID}, true, all))
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{v4ID}, all))
|
||||
require.True(t, rs.IsSelected(v6ID), "precondition: orphaned v6 selection survives v4 deselect")
|
||||
|
||||
m := &DefaultManager{routeSelector: rs}
|
||||
|
||||
v4Route := &route.Route{NetID: v4ID, Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: v6ID, Network: netip.MustParsePrefix("::/0")}
|
||||
clientRoutes := route.HAMap{
|
||||
"Exit Node (raspberrypi)|0.0.0.0/0": {v4Route},
|
||||
"Exit Node (raspberrypi)-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(clientRoutes)
|
||||
|
||||
assert.False(t, rs.IsSelected(v6ID), "v6 pair must follow the v4 base deselect after the management update")
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(clientRoutes)
|
||||
assert.Empty(t, filtered, "deselected v4 exit node must not leak its ::/0 pair onto the tunnel")
|
||||
}
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
type Notifier struct {
|
||||
initialRoutes []*route.Route
|
||||
currentRoutes []*route.Route
|
||||
fakeIPRoutes []*route.Route
|
||||
fakeIPRoutes []*route.Route
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
@@ -119,3 +119,7 @@ func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
sort.Strings(initialStrings)
|
||||
return initialStrings
|
||||
}
|
||||
|
||||
func (n *Notifier) Close() {
|
||||
// unused
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -14,19 +15,26 @@ import (
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
currentPrefixes []string
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
listener listener.NetworkChangeListener
|
||||
queue *list.List
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
n := &Notifier{
|
||||
queue: list.New(),
|
||||
}
|
||||
n.cond = sync.NewCond(&n.mu)
|
||||
go n.deliverLoop()
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
@@ -43,32 +51,52 @@ func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
newNets := make([]string, 0, len(prefixes))
|
||||
for _, prefix := range prefixes {
|
||||
newNets = append(newNets, prefix.String())
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
|
||||
n.mu.Lock()
|
||||
if slices.Equal(n.currentPrefixes, newNets) {
|
||||
n.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
routes := strings.Join(n.currentPrefixes, ",")
|
||||
n.queue.PushBack(routes)
|
||||
n.cond.Signal()
|
||||
n.mu.Unlock()
|
||||
}
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.currentPrefixes, ","))
|
||||
}(n.listener)
|
||||
func (n *Notifier) Close() {
|
||||
n.mu.Lock()
|
||||
n.closed = true
|
||||
n.cond.Signal()
|
||||
n.mu.Unlock()
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Notifier) deliverLoop() {
|
||||
for {
|
||||
n.mu.Lock()
|
||||
for n.queue.Len() == 0 && !n.closed {
|
||||
n.cond.Wait()
|
||||
}
|
||||
if n.closed && n.queue.Len() == 0 {
|
||||
n.mu.Unlock()
|
||||
return
|
||||
}
|
||||
routes := n.queue.Remove(n.queue.Front()).(string)
|
||||
l := n.listener
|
||||
n.mu.Unlock()
|
||||
|
||||
if l != nil {
|
||||
l.OnNetworkChanged(routes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,3 +38,7 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
func (n *Notifier) Close() {
|
||||
// unused
|
||||
}
|
||||
|
||||
@@ -121,9 +121,12 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Check if the prefix is part of any local subnets
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
// BSDs blackhole a /32 added inside a directly-connected subnet; Linux/Windows need it to beat the wt0 route.
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "freebsd", "netbsd", "openbsd", "dragonfly":
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -132,6 +131,33 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
return rs.isSelectedLocked(routeID)
|
||||
}
|
||||
|
||||
// SyncPairedSelection forces pairedID's explicit selection state to match baseID's,
|
||||
// so a synthesized "-v6" exit route always follows its v4 base: selecting or
|
||||
// deselecting the v4 exit node governs the ::/0 pair, and any stale (orphaned)
|
||||
// explicit state on the v6 entry is reset. The v4/v6 exit pair is treated as a single
|
||||
// toggle, so the v6 entry carries no independent selection of its own.
|
||||
func (rs *RouteSelector) SyncPairedSelection(baseID, pairedID route.NetID) {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
return
|
||||
}
|
||||
|
||||
_, baseSelected := rs.selectedRoutes[baseID]
|
||||
_, baseDeselected := rs.deselectedRoutes[baseID]
|
||||
|
||||
delete(rs.selectedRoutes, pairedID)
|
||||
delete(rs.deselectedRoutes, pairedID)
|
||||
|
||||
switch {
|
||||
case baseSelected:
|
||||
rs.selectedRoutes[pairedID] = struct{}{}
|
||||
case baseDeselected:
|
||||
rs.deselectedRoutes[pairedID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
rs.mu.RLock()
|
||||
@@ -151,14 +177,13 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
}
|
||||
|
||||
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
|
||||
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
|
||||
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
|
||||
// transparently apply to the synthesized v6 entry.
|
||||
// The lookup is literal; v4/v6 exit pairs are kept consistent at write time via SyncPairedSelection,
|
||||
// so a synthesized "-v6" entry carries the same explicit state as its v4 base.
|
||||
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
|
||||
return rs.hasUserSelectionForRouteLocked(routeID)
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
|
||||
@@ -187,83 +212,6 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
||||
return filtered
|
||||
}
|
||||
|
||||
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
|
||||
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
|
||||
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
|
||||
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
|
||||
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
|
||||
name := string(id)
|
||||
if !strings.HasSuffix(name, route.V6ExitSuffix) {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.selectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.deselectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return !deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return true
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
return deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
}
|
||||
|
||||
func isExitNode(rt []*route.Route) bool {
|
||||
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) applyExitNodeFilter(
|
||||
id route.HAUniqueID,
|
||||
netID route.NetID,
|
||||
rt []*route.Route,
|
||||
out route.HAMap,
|
||||
) {
|
||||
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
|
||||
// drops the synthesized v6 entry that lacks its own explicit state.
|
||||
effective := rs.effectiveNetID(netID)
|
||||
if rs.hasUserSelectionForRouteLocked(effective) {
|
||||
if rs.isSelectedLocked(effective) {
|
||||
out[id] = rt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no explicit selection for this route: defer to management's SkipAutoApply flag
|
||||
sel := collectSelected(rt)
|
||||
if len(sel) > 0 {
|
||||
out[id] = sel
|
||||
}
|
||||
}
|
||||
|
||||
func collectSelected(rt []*route.Route) []*route.Route {
|
||||
var sel []*route.Route
|
||||
for _, r := range rt {
|
||||
if !r.SkipAutoApply {
|
||||
sel = append(sel, r)
|
||||
}
|
||||
}
|
||||
return sel
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface
|
||||
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
||||
rs.mu.RLock()
|
||||
@@ -317,3 +265,59 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return !deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return true
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
return deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) applyExitNodeFilter(
|
||||
id route.HAUniqueID,
|
||||
netID route.NetID,
|
||||
rt []*route.Route,
|
||||
out route.HAMap,
|
||||
) {
|
||||
if rs.hasUserSelectionForRouteLocked(netID) {
|
||||
if rs.isSelectedLocked(netID) {
|
||||
out[id] = rt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no explicit selection for this route: defer to management's SkipAutoApply flag
|
||||
sel := collectSelected(rt)
|
||||
if len(sel) > 0 {
|
||||
out[id] = sel
|
||||
}
|
||||
}
|
||||
|
||||
func isExitNode(rt []*route.Route) bool {
|
||||
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
|
||||
}
|
||||
|
||||
func collectSelected(rt []*route.Route) []*route.Route {
|
||||
var sel []*route.Route
|
||||
for _, r := range rt {
|
||||
if !r.SkipAutoApply {
|
||||
sel = append(sel, r)
|
||||
}
|
||||
}
|
||||
return sel
|
||||
}
|
||||
|
||||
@@ -330,39 +330,73 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
|
||||
assert.Len(t, filtered, 0) // No routes should be selected
|
||||
}
|
||||
|
||||
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
|
||||
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
|
||||
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
|
||||
// v4 base, so legacy persisted selections that predate v6 pairing transparently
|
||||
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
|
||||
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
|
||||
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
// TestRouteSelector_V6ExitPairSync covers SyncPairedSelection, which keeps a v4
|
||||
// exit node and its synthesized "-v6" counterpart consistent. The selector itself
|
||||
// is literal and never infers a v6 entry's state from its v4 base; callers that know
|
||||
// the pairing (exit-node code paths) call SyncPairedSelection to force the v6 entry
|
||||
// to follow the base, treating the pair as a single toggle.
|
||||
func TestRouteSelector_V6ExitPairSync(t *testing.T) {
|
||||
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
|
||||
|
||||
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
|
||||
t.Run("selector lookups stay literal without sync", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
|
||||
// The selector does not pair-resolve: the v6 entry is independent until synced.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 entry has no state of its own")
|
||||
assert.True(t, rs.IsSelected("exit1-v6"), "unsynced v6 entry stays selected by default")
|
||||
|
||||
// unrelated v6 with no v4 base touched is unaffected
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
|
||||
// A route literally named "exit1-something" must never pair-resolve either.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
|
||||
})
|
||||
|
||||
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
|
||||
// via the mirror; the mirror only applies in exit-node code paths.
|
||||
assert.False(t, rs.IsSelected("corp"))
|
||||
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
|
||||
})
|
||||
|
||||
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
|
||||
t.Run("sync mirrors deselected v4 base onto v6", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.IsSelected("exit1"))
|
||||
assert.False(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base deselect")
|
||||
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 carries explicit deselect after sync")
|
||||
})
|
||||
|
||||
t.Run("sync mirrors selected v4 base onto v6", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1"}, false, all))
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.True(t, rs.IsSelected("exit1"))
|
||||
assert.True(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base select")
|
||||
})
|
||||
|
||||
t.Run("sync clears v6 state when base has no explicit selection", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
|
||||
require.True(t, rs.HasUserSelectionForRoute("exit1-v6"))
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"),
|
||||
"v6 explicit state is cleared so it follows management like its base")
|
||||
})
|
||||
|
||||
// Regression for the observed bug (see netbird-engine.log): persisted state has
|
||||
// the v4 base deselected but the v6 sibling explicitly selected (orphaned). The
|
||||
// sync must reset the orphan so the ::/0 route does not leak onto the tunnel.
|
||||
t.Run("sync clears orphaned explicit v6 selection on deselected base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
// Prior state: both explicitly selected, then only the v4 base deselected,
|
||||
// leaving the v6 entry as a stale explicit selection.
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1", "exit1-v6"}, true, all))
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
require.True(t, rs.IsSelected("exit1-v6"), "precondition: orphaned v6 selection")
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.IsSelected("exit1-v6"), "orphaned v6 selection reset to follow v4 deselect")
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
@@ -370,23 +404,14 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
"exit1|0.0.0.0/0": {v4Route},
|
||||
"exit1-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
|
||||
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
|
||||
assert.Empty(t, filtered, "deselecting v4 base must drop the v6 pair even if it was explicitly selected before")
|
||||
})
|
||||
|
||||
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
// A route literally named "exit1-something" must not pair-resolve.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
|
||||
})
|
||||
|
||||
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
|
||||
t.Run("filter drops synced v6 pair of deselected v4 base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
@@ -399,6 +424,15 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
|
||||
})
|
||||
|
||||
t.Run("deselectAll makes sync a no-op", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
rs.DeselectAllRoutes()
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "sync must not write explicit state under deselectAll")
|
||||
})
|
||||
|
||||
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
types "github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -54,6 +56,7 @@ type selectRoute struct {
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
Status string
|
||||
extraNetworks []netip.Prefix
|
||||
}
|
||||
|
||||
@@ -65,6 +68,8 @@ func init() {
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
stateFile string
|
||||
cacheDir string
|
||||
logFilePath string
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
@@ -75,16 +80,21 @@ type Client struct {
|
||||
onHostDnsFn func([]string)
|
||||
dnsManager dns.IosDnsManager
|
||||
loginComplete bool
|
||||
connectClient *internal.ConnectClient
|
||||
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
|
||||
preloadedConfig *profilemanager.Config
|
||||
|
||||
stateMu sync.RWMutex
|
||||
connectClient *internal.ConnectClient
|
||||
config *profilemanager.Config
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||
func NewClient(cfgFile, stateFile, cacheDir, logFilePath, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
stateFile: stateFile,
|
||||
cacheDir: cacheDir,
|
||||
logFilePath: logFilePath,
|
||||
deviceName: deviceName,
|
||||
osName: osName,
|
||||
osVersion: osVersion,
|
||||
@@ -161,8 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, connectClient)
|
||||
// Persist the latest sync response so DebugBundle can include the network
|
||||
// map. On iOS this is backed by disk to keep it out of the constrained
|
||||
// process memory (see the syncstore package).
|
||||
connectClient.SetSyncResponsePersistence(true)
|
||||
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -174,6 +189,84 @@ func (c *Client) Stop() {
|
||||
}
|
||||
|
||||
c.ctxCancel()
|
||||
c.setState(nil, nil)
|
||||
}
|
||||
|
||||
// DebugBundle generates a debug bundle, uploads it and returns the upload key.
|
||||
// It works with or without a running engine: when the engine is up it reuses
|
||||
// the live config, sync response and client metrics; otherwise it loads the
|
||||
// config from disk (or the preloaded tvOS config).
|
||||
func (c *Client) DebugBundle(anonymize bool) (string, error) {
|
||||
cfg, cc := c.stateSnapshot()
|
||||
|
||||
// If the engine hasn't been started, load config so we can reach management.
|
||||
if cfg == nil {
|
||||
if c.preloadedConfig != nil {
|
||||
cfg = c.preloadedConfig
|
||||
} else {
|
||||
var err error
|
||||
// Use DirectUpdateOrCreateConfig to avoid atomic file operations
|
||||
// (temp file + rename) blocked by the tvOS sandbox.
|
||||
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
StateFilePath: c.stateFile,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deps := debug.GeneratorDependencies{
|
||||
InternalConfig: cfg,
|
||||
StatusRecorder: c.recorder,
|
||||
TempDir: c.cacheDir,
|
||||
StatePath: c.stateFile,
|
||||
LogPath: c.logFilePath,
|
||||
}
|
||||
|
||||
if cc != nil {
|
||||
resp, err := cc.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
deps.SyncResponse = resp
|
||||
|
||||
if e := cc.Engine(); e != nil {
|
||||
if cm := e.GetClientMetrics(); cm != nil {
|
||||
deps.ClientMetrics = cm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
deps,
|
||||
debug.BundleConfig{
|
||||
Anonymize: anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("debug bundle uploaded with key %s", key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
@@ -227,6 +320,16 @@ func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
// IsLoginRequiredCached reports whether the LAST observed management error was an
|
||||
// auth failure (PermissionDenied/InvalidArgument), using the in-memory status
|
||||
// recorder. Unlike IsLoginRequired() it performs NO network call, so it is safe to
|
||||
// call from the connection listener during teardown (e.g. onDisconnected) without
|
||||
// blocking on a slow or unavailable network. Returns false while connected to
|
||||
// management or when the last error was not auth-related.
|
||||
func (c *Client) IsLoginRequiredCached() bool {
|
||||
return c.recorder.IsLoginRequired()
|
||||
}
|
||||
|
||||
func (c *Client) IsLoginRequired() bool {
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
@@ -354,11 +457,12 @@ func (c *Client) ClearLoginComplete() {
|
||||
}
|
||||
|
||||
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
if c.connectClient == nil {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -377,9 +481,57 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
// Compute each route's connection status in the core (mirroring the Android
|
||||
// bridge), so the UI doesn't have to infer it by string-matching the joined
|
||||
// Network value against peer routes. For a merged exit node the status reflects
|
||||
// whichever of the v4/v6 prefixes is served by a connected peer; for dynamic
|
||||
// (DNS) routes the peer route key is the domain pattern (see dynamic.Route.String).
|
||||
connectedRoutes := c.connectedRouteSet()
|
||||
for _, r := range routes {
|
||||
r.Status = routeStatus(r, connectedRoutes)
|
||||
}
|
||||
|
||||
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
|
||||
}
|
||||
|
||||
// connectedRouteSet returns the set of route keys (as strings) currently served by a
|
||||
// connected peer, gathered across all connected peers' route tables. The keys match
|
||||
// what the route manager records: a prefix string for static routes (e.g. "0.0.0.0/0")
|
||||
// and the domain pattern for dynamic routes (e.g. "*.example.com").
|
||||
func (c *Client) connectedRouteSet() map[string]struct{} {
|
||||
connected := map[string]struct{}{}
|
||||
for _, p := range c.recorder.GetFullStatus().Peers {
|
||||
if p.ConnStatus != peer.StatusConnected {
|
||||
continue
|
||||
}
|
||||
for r := range p.GetRoutes() {
|
||||
connected[r] = struct{}{}
|
||||
}
|
||||
}
|
||||
return connected
|
||||
}
|
||||
|
||||
// routeStatus reports "Connected" if any of the route's keys is served by a connected
|
||||
// peer: the primary Network prefix, an extra v6 network of a merged exit node, or the
|
||||
// domain pattern for a dynamic DNS route. Otherwise "Idle".
|
||||
func routeStatus(r *selectRoute, connectedRoutes map[string]struct{}) string {
|
||||
keys := make([]string, 0, 1+len(r.extraNetworks))
|
||||
if len(r.Domains) > 0 {
|
||||
keys = append(keys, r.Domains.SafeString())
|
||||
} else {
|
||||
keys = append(keys, r.Network.String())
|
||||
}
|
||||
for _, extra := range r.extraNetworks {
|
||||
keys = append(keys, extra.String())
|
||||
}
|
||||
for _, k := range keys {
|
||||
if _, ok := connectedRoutes[k]; ok {
|
||||
return peer.StatusConnected.String()
|
||||
}
|
||||
}
|
||||
return peer.StatusIdle.String()
|
||||
}
|
||||
|
||||
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
|
||||
var routes []*selectRoute
|
||||
for id, rt := range routesMap {
|
||||
@@ -462,6 +614,7 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
||||
Network: netStr,
|
||||
Domains: &domainDetails,
|
||||
Selected: r.Selected,
|
||||
Status: r.Status,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -470,11 +623,12 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
||||
}
|
||||
|
||||
func (c *Client) SelectRoute(id string) error {
|
||||
if c.connectClient == nil {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -500,10 +654,11 @@ func (c *Client) SelectRoute(id string) error {
|
||||
}
|
||||
|
||||
func (c *Client) DeselectRoute(id string) error {
|
||||
if c.connectClient == nil {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
engine := c.connectClient.Engine()
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -527,6 +682,22 @@ func (c *Client) DeselectRoute(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setState stores the running engine state so DebugBundle can reuse the live
|
||||
// config and ConnectClient. It is cleared on Stop.
|
||||
func (c *Client) setState(cfg *profilemanager.Config, cc *internal.ConnectClient) {
|
||||
c.stateMu.Lock()
|
||||
defer c.stateMu.Unlock()
|
||||
c.config = cfg
|
||||
c.connectClient = cc
|
||||
}
|
||||
|
||||
// stateSnapshot returns the current config and ConnectClient under the lock.
|
||||
func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClient) {
|
||||
c.stateMu.RLock()
|
||||
defer c.stateMu.RUnlock()
|
||||
return c.config, c.connectClient
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
ds := d.String()
|
||||
dotIndex := strings.Index(ds, ".")
|
||||
|
||||
@@ -20,6 +20,7 @@ type RoutesSelectionInfo struct {
|
||||
Network string
|
||||
Domains *DomainDetails
|
||||
Selected bool
|
||||
Status string
|
||||
}
|
||||
|
||||
type DomainCollection interface {
|
||||
|
||||
@@ -84,48 +84,16 @@ func NewPolicy(values map[string]any) *Policy {
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// PolicyFetcher is implemented by mobile platforms (Android / iOS) that
|
||||
// push the OS-managed configuration into the Go runtime instead of
|
||||
// having Go read an on-disk source directly. Desktop platforms ignore
|
||||
// this interface — Loader.loadPlatform on windows/darwin reads the
|
||||
// registry / plist on its own. A Loader constructed with a non-nil
|
||||
// fetcher delegates to it on mobile; passing nil disables MDM
|
||||
// enforcement (loadPlatform returns nil values).
|
||||
type PolicyFetcher interface {
|
||||
Fetch() map[string]any
|
||||
}
|
||||
|
||||
// Loader is the DI-friendly entry point for reading the active MDM
|
||||
// policy. Construct one at the daemon's lifecycle owner (Server on
|
||||
// desktop, gomobile-exposed bridge on mobile) and pass it to anything
|
||||
// that needs to read MDM state (the reload ticker, profilemanager's
|
||||
// Config). Each callsite has the Loader handed in instead of looking
|
||||
// up package-level state.
|
||||
type Loader struct {
|
||||
fetcher PolicyFetcher
|
||||
}
|
||||
|
||||
// NewLoader constructs a Loader. The fetcher is consulted only on
|
||||
// mobile builds (ios || android); on desktop it is unused but accepted
|
||||
// to keep a single constructor signature across platforms — pass nil
|
||||
// on desktop.
|
||||
func NewLoader(f PolicyFetcher) *Loader {
|
||||
return &Loader{fetcher: f}
|
||||
}
|
||||
|
||||
// Load reads the platform-native MDM configuration and returns a
|
||||
// Policy. Returns an empty (but non-nil) Policy when no source is
|
||||
// present, the source is empty, or the platform is unsupported.
|
||||
// LoadPolicy reads the platform-native MDM configuration. Returns an
|
||||
// empty (but non-nil) Policy when no source is present, the source is
|
||||
// empty, or the platform is unsupported.
|
||||
//
|
||||
// Diagnostic logging differentiates the three states:
|
||||
// - source absent / unsupported platform: trace log only
|
||||
// - source present, zero keys: info "MDM enrolled (no managed keys)"
|
||||
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
|
||||
func (l *Loader) Load() *Policy {
|
||||
if l == nil {
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
values, err := l.loadPlatform()
|
||||
func LoadPolicy() *Policy {
|
||||
values, err := loadPlatformPolicy()
|
||||
if err != nil {
|
||||
log.Tracef("MDM policy load: %v", err)
|
||||
return &Policy{values: map[string]any{}}
|
||||
|
||||
@@ -25,10 +25,8 @@ import (
|
||||
// writable plist, as a defense against tampered installs.
|
||||
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
|
||||
// loadPlatform reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. The Loader's fetcher
|
||||
// field is unused on this platform — the plist is the authoritative
|
||||
// source. Returns:
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. Returns:
|
||||
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
|
||||
// NetBird, or admin has not yet pushed a payload)
|
||||
// - (map, nil) with N entries when N managed values are present
|
||||
@@ -41,13 +39,7 @@ const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
// skipped so a stray entry in the payload does not block startup.
|
||||
// Native plist value types map naturally onto the Policy accessor
|
||||
// expectations (GetString / GetBool / GetInt / GetStringSlice).
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
// Honour the injected fetcher when present so tests (and any
|
||||
// future non-macOS MDM channel) can short-circuit the plist read
|
||||
// with a scripted policy.
|
||||
if l != nil && l.fetcher != nil {
|
||||
return l.fetcher.Fetch(), nil
|
||||
}
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
f, err := os.Open(policyPlistPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatform reads the OS-managed configuration via the native
|
||||
// PolicyFetcher injected at Loader construction. Returns
|
||||
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
|
||||
// "no MDM source present" — when no fetcher was provided.
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
if l == nil || l.fetcher == nil {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
|
||||
return nil, nil
|
||||
}
|
||||
return l.fetcher.Fetch(), nil
|
||||
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
|
||||
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
|
||||
// resulting dictionary in-process via a gomobile entry point that lands in
|
||||
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
|
||||
// builds and returns (nil, nil) — the platform-absent sentinel that
|
||||
// LoadPolicy in policy.go treats as "no MDM source present".
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -2,17 +2,13 @@
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatform reads the MDM policy on platforms without a native MDM
|
||||
// channel (Linux, FreeBSD). When no fetcher was injected the policy is
|
||||
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
|
||||
// "MDM enforcement disabled". A non-nil fetcher takes precedence: it
|
||||
// is the test-seam used by unit tests to inject a scripted policy
|
||||
// without touching the OS, and the same hook supports any future
|
||||
// non-mobile OS that grows an out-of-band MDM channel.
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
if l != nil && l.fetcher != nil {
|
||||
return l.fetcher.Fetch(), nil
|
||||
}
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
|
||||
// loadPlatformPolicy returns no policy on platforms without an MDM channel
|
||||
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
|
||||
// the feature did not exist. Returns (nil, nil) — the platform-absent
|
||||
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
|
||||
// source present"; an error here would just translate to the same
|
||||
// outcome with an extra log line.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -150,12 +150,10 @@ func TestPolicy_GetStringSlice(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoader_NilFetcherReturnsEmpty(t *testing.T) {
|
||||
// Loader.Load with no fetcher (desktop construction) must degrade
|
||||
// gracefully and never return nil; on linux loadPlatform is a stub
|
||||
// returning (nil, nil), and Load is expected to translate that
|
||||
// into a non-nil empty Policy.
|
||||
p := NewLoader(nil).Load()
|
||||
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
|
||||
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
|
||||
// degrade gracefully and never return nil.
|
||||
p := LoadPolicy()
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
|
||||
@@ -61,10 +61,8 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
|
||||
}
|
||||
}
|
||||
|
||||
// loadPlatform reads the MDM-managed configuration from the Windows
|
||||
// registry under HKLM\Software\Policies\NetBird. The Loader's fetcher
|
||||
// field is unused on this platform — the registry is the
|
||||
// authoritative source. Returns:
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the
|
||||
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
|
||||
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
|
||||
// - (map, nil) with N entries when N managed values are set (N may be 0)
|
||||
// - (nil, err) on open / enumerate registry errors
|
||||
@@ -72,13 +70,7 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
|
||||
// Per-value type coercion + skip-on-error is delegated to
|
||||
// readRegistryValue. Unknown value names are logged and skipped so a
|
||||
// malformed deployment does not block startup.
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
// Honour the injected fetcher when present so tests (and any
|
||||
// future non-Windows MDM channel) can short-circuit the registry
|
||||
// read with a scripted policy.
|
||||
if l != nil && l.fetcher != nil {
|
||||
return l.fetcher.Fetch(), nil
|
||||
}
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if errors.Is(err, registry.ErrNotExist) {
|
||||
|
||||
@@ -15,33 +15,33 @@ import (
|
||||
// instead, hence anticipating the ticker mechanism entirely.
|
||||
const DefaultReloadInterval = 1 * time.Minute
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via the
|
||||
// injected Loader and invokes the onChange callback (supplied to Run)
|
||||
// whenever the observed Policy diverges from the last observation
|
||||
// (added / removed / changed keys). Launch with Run from a goroutine;
|
||||
// cancel the supplied context to stop.
|
||||
// policyLoader is the indirection through which the ticker reads the
|
||||
// OS-native policy, both for the initial observation and on every tick.
|
||||
// Production points it at LoadPolicy; tests in this package override it to
|
||||
// feed a scripted sequence of policies without touching the real OS store.
|
||||
var policyLoader = LoadPolicy
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
|
||||
// invokes the onChange callback (supplied to Run) whenever the observed
|
||||
// Policy diverges from the last observation (added / removed / changed
|
||||
// keys). Launch with Run from a goroutine; cancel the supplied context
|
||||
// to stop.
|
||||
type Ticker struct {
|
||||
interval time.Duration
|
||||
loader *Loader
|
||||
prev *Policy
|
||||
}
|
||||
|
||||
// NewTicker constructs a Ticker that will re-read the OS-native policy
|
||||
// every reloadInterval once Run is called. The Loader is injected so
|
||||
// the ticker doesn't depend on any package-level state — production
|
||||
// passes the daemon-owned Loader, tests pass a fake Loader (built with
|
||||
// a fake PolicyFetcher).
|
||||
//
|
||||
// The initial snapshot is populated by calling loader.Load() at
|
||||
// every reloadInterval once Run is called.
|
||||
// The initial snapshot is populated by calling policyLoader at
|
||||
// construction time so the first tick only fires
|
||||
// onChange when the policy actually changed since boot — without
|
||||
// this baseline the first tick would report every currently-managed
|
||||
// key as "added" and trigger a spurious engine restart.
|
||||
func NewTicker(reloadInterval time.Duration, loader *Loader) *Ticker {
|
||||
func NewTicker(reloadInterval time.Duration) *Ticker {
|
||||
return &Ticker{
|
||||
interval: reloadInterval,
|
||||
loader: loader,
|
||||
prev: loader.Load(),
|
||||
prev: policyLoader(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) erro
|
||||
log.Info("MDM policy reload ticker stopped")
|
||||
return
|
||||
case <-tk.C:
|
||||
curr := t.loader.Load()
|
||||
curr := policyLoader()
|
||||
if policiesEqual(t.prev, curr) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -13,40 +13,28 @@ import (
|
||||
// testReloadInterval for speeding up the ticker cadence under `go test`
|
||||
const testReloadInterval = 1 * time.Second
|
||||
|
||||
// fakePolicyFetcher implements PolicyFetcher returning a scripted
|
||||
// policy map. Goroutine-safe so the test can mutate the script while
|
||||
// the ticker is observing it.
|
||||
type fakePolicyFetcher struct {
|
||||
mu sync.Mutex
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
func (f *fakePolicyFetcher) Fetch() map[string]any {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.values == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(f.values))
|
||||
for k, v := range f.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (f *fakePolicyFetcher) set(values map[string]any) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.values = values
|
||||
// withPolicyLoader overrides the package-level policyLoader for the duration
|
||||
// of the test so the ticker observes a scripted policy instead of the real
|
||||
// OS-native store. The original loader is restored on cleanup.
|
||||
func withPolicyLoader(t *testing.T, fn func() *Policy) {
|
||||
t.Helper()
|
||||
prev := policyLoader
|
||||
policyLoader = fn
|
||||
t.Cleanup(func() { policyLoader = prev })
|
||||
}
|
||||
|
||||
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
fetcher := &fakePolicyFetcher{} // initial observation: empty (no enforcement)
|
||||
loader := NewLoader(fetcher)
|
||||
var mu sync.Mutex
|
||||
current := NewPolicy(nil) // initial observation: empty (no enforcement)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
})
|
||||
|
||||
type change struct{ prev, curr *Policy }
|
||||
changes := make(chan change, 1)
|
||||
tk := NewTicker(testReloadInterval, loader)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
require.Equal(t, testReloadInterval, tk.interval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -61,13 +49,15 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
// Stop Run and wait for it to exit before returning, so the test
|
||||
// goroutine doesn't race the still-running ticker.
|
||||
// Stop Run and wait for it to exit before returning, so the policyLoader
|
||||
// restore in t.Cleanup can't race the ticker goroutine still reading it.
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Flip the OS-observed policy from empty to one managed key. The
|
||||
// next tick must detect the diff and invoke onChange.
|
||||
fetcher.set(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
// Flip the OS-observed policy from empty to one managed key. The next
|
||||
// tick must detect the diff and invoke onChange.
|
||||
mu.Lock()
|
||||
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
mu.Unlock()
|
||||
|
||||
select {
|
||||
case c := <-changes:
|
||||
@@ -79,11 +69,12 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
fetcher := &fakePolicyFetcher{values: map[string]any{KeyBlockInbound: true}}
|
||||
loader := NewLoader(fetcher)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
return NewPolicy(map[string]any{KeyBlockInbound: true})
|
||||
})
|
||||
|
||||
fired := make(chan struct{}, 1)
|
||||
tk := NewTicker(testReloadInterval, loader)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
@@ -99,8 +90,8 @@ func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
}()
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes,
|
||||
// so the diff guard must suppress the callback entirely.
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
|
||||
// diff guard must suppress the callback entirely.
|
||||
select {
|
||||
case <-fired:
|
||||
t.Fatal("onChange fired despite an unchanged policy")
|
||||
|
||||
@@ -1849,10 +1849,13 @@ func (x *ManagementState) GetError() string {
|
||||
|
||||
// RelayState contains the latest state of the relay
|
||||
type RelayState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
URI string `protobuf:"bytes,1,opt,name=URI,proto3" json:"URI,omitempty"`
|
||||
Available bool `protobuf:"varint,2,opt,name=available,proto3" json:"available,omitempty"`
|
||||
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
URI string `protobuf:"bytes,1,opt,name=URI,proto3" json:"URI,omitempty"`
|
||||
Available bool `protobuf:"varint,2,opt,name=available,proto3" json:"available,omitempty"`
|
||||
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
|
||||
// transport is the negotiated relay transport (e.g. "ws", "quic"),
|
||||
// empty for stun/turn probes or when not connected.
|
||||
Transport string `protobuf:"bytes,4,opt,name=transport,proto3" json:"transport,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1908,6 +1911,13 @@ func (x *RelayState) GetError() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RelayState) GetTransport() string {
|
||||
if x != nil {
|
||||
return x.Transport
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type NSGroupState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Servers []string `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
|
||||
@@ -6486,12 +6496,13 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x0fManagementState\x12\x10\n" +
|
||||
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
|
||||
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\"R\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\"p\n" +
|
||||
"\n" +
|
||||
"RelayState\x12\x10\n" +
|
||||
"\x03URI\x18\x01 \x01(\tR\x03URI\x12\x1c\n" +
|
||||
"\tavailable\x18\x02 \x01(\bR\tavailable\x12\x14\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\"r\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\x12\x1c\n" +
|
||||
"\ttransport\x18\x04 \x01(\tR\ttransport\"r\n" +
|
||||
"\fNSGroupState\x12\x18\n" +
|
||||
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
|
||||
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
|
||||
|
||||
@@ -378,6 +378,9 @@ message RelayState {
|
||||
string URI = 1;
|
||||
bool available = 2;
|
||||
string error = 3;
|
||||
// transport is the negotiated relay transport (e.g. "ws", "quic"),
|
||||
// empty for stun/turn probes or when not connected.
|
||||
string transport = 4;
|
||||
}
|
||||
|
||||
message NSGroupState {
|
||||
|
||||
@@ -20,6 +20,10 @@ import (
|
||||
// a no-op echo, never as a conflict with the policy.
|
||||
const preSharedKeyRedactedSentinel = "**********"
|
||||
|
||||
// loadMDMPolicy is the indirection used by server handlers to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// conflictCheck is a value-aware comparison between a single field in
|
||||
// the incoming request and the corresponding MDM-enforced value. It
|
||||
// runs only when the field was actually set in the request (presence
|
||||
|
||||
@@ -110,15 +110,6 @@ type Server struct {
|
||||
// stopped by the rootCtx cancellation.
|
||||
mdmTicker *mdm.Ticker
|
||||
|
||||
// mdmLoader is the daemon-owned source of the active MDM policy.
|
||||
// Constructed once during Server.Start (with a nil PolicyFetcher on
|
||||
// desktop — the build-tagged Loader.loadPlatform reads the OS
|
||||
// registry / plist directly) and injected into every consumer:
|
||||
// mdmTicker for its periodic reload, the SetConfig / Login MDM
|
||||
// gates for conflict detection, and every Config produced via
|
||||
// getConfig() so its apply() picks up the same overlay.
|
||||
mdmLoader *mdm.Loader
|
||||
|
||||
updateManager *updater.Manager
|
||||
|
||||
jwtCache *jwtCache
|
||||
@@ -182,14 +173,8 @@ func (s *Server) Start() error {
|
||||
// Runs re-resolves Config (re-running profilemanager.Config.apply which
|
||||
// applies the freshly-read MDM policy as the last layer) and brings
|
||||
// the engine back with the new values.
|
||||
if s.mdmLoader == nil {
|
||||
// Desktop builds pass a nil PolicyFetcher: the Loader's
|
||||
// build-tagged loadPlatform reads the OS source directly
|
||||
// (registry on Windows, plist on macOS, no-op elsewhere).
|
||||
s.mdmLoader = mdm.NewLoader(nil)
|
||||
}
|
||||
if s.mdmTicker == nil {
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval, s.mdmLoader)
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
|
||||
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
|
||||
}
|
||||
|
||||
@@ -385,7 +370,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
// by the active MDM policy. The error carries an MDMManagedFields-
|
||||
// Violation detail listing the offending key names. Non-conflicting
|
||||
// fields in the same request are not applied either.
|
||||
policy := s.mdmLoader.Load()
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -511,7 +496,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
policy := s.mdmLoader.Load()
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1103,12 +1088,6 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
|
||||
return nil, false, fmt.Errorf("failed to get config: %w", err)
|
||||
}
|
||||
|
||||
// Apply the daemon-owned MDM policy on top of the just-resolved
|
||||
// Config. profilemanager's apply() initialises the policy to
|
||||
// empty — the Loader lives outside Config, so this overlay step
|
||||
// is driven externally here.
|
||||
config.ApplyMDMPolicy(s.mdmLoader.Load())
|
||||
|
||||
return config, configExisted, nil
|
||||
}
|
||||
|
||||
@@ -1163,9 +1142,6 @@ func (s *Server) logoutFromProfile(ctx context.Context, profileName, username st
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile '%s' not found", profileName)
|
||||
}
|
||||
// Honour any MDM-enforced ManagementURL when issuing the logout
|
||||
// RPC: the user-stored value may have been overridden by policy.
|
||||
config.ApplyMDMPolicy(s.mdmLoader.Load())
|
||||
|
||||
return s.sendLogoutRequestWithConfig(ctx, config)
|
||||
}
|
||||
@@ -1598,11 +1574,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
||||
}
|
||||
// Overlay the active MDM policy so the response's MDMManagedFields
|
||||
// list reflects what the GUI / CLI must render as read-only.
|
||||
// profilemanager.GetConfig itself returns a Config without the
|
||||
// overlay (Loader lives outside profilemanager).
|
||||
cfg.ApplyMDMPolicy(s.mdmLoader.Load())
|
||||
managementURL := cfg.ManagementURL
|
||||
adminURL := cfg.AdminURL
|
||||
|
||||
|
||||
@@ -16,40 +16,14 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// fakeMDMFetcher implements mdm.PolicyFetcher returning a pre-set
|
||||
// policy map. Tests build one per Server instance to inject a
|
||||
// scripted MDM overlay via a Loader rather than via package-level state.
|
||||
type fakeMDMFetcher struct{ values map[string]any }
|
||||
|
||||
func (f *fakeMDMFetcher) Fetch() map[string]any { return f.values }
|
||||
|
||||
// withMDMPolicy installs an mdm.Loader on the given Server whose
|
||||
// loadPlatform returns the supplied Policy's underlying values. Use
|
||||
// after setupServerWithProfile to inject the scripted policy the
|
||||
// SetConfig / Login MDM gates will observe.
|
||||
func withMDMPolicy(t *testing.T, s *Server, policy *mdm.Policy) {
|
||||
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
|
||||
// so SetConfig observes the supplied Policy. Restores the original loader
|
||||
// at test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
values := map[string]any{}
|
||||
if policy != nil {
|
||||
for _, k := range policy.ManagedKeys() {
|
||||
if v, ok := policy.GetString(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetBool(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetInt(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetStringSlice(k); ok {
|
||||
values[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mdmLoader = mdm.NewLoader(&fakeMDMFetcher{values: values})
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
|
||||
@@ -115,11 +89,12 @@ func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
@@ -131,13 +106,14 @@ func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
blockInbound := false
|
||||
rosenpassEnabled := false
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
@@ -161,11 +137,12 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
// enforced field AND a non-enforced field (RosenpassEnabled).
|
||||
// The whole request must be rejected — non-conflicting fields are not
|
||||
// applied either.
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
@@ -187,11 +164,12 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
|
||||
// Request must succeed.
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
@@ -205,8 +183,9 @@ func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
|
||||
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
// No MDM policy active: any field can be written.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(nil))
|
||||
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
|
||||
@@ -98,6 +98,7 @@ type RelayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
|
||||
}
|
||||
|
||||
type RelayStateOutput struct {
|
||||
@@ -219,7 +220,8 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
|
||||
RelayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
Error: relayErrorString(relay.GetError()),
|
||||
Transport: relay.GetTransport(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -235,6 +237,12 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
// relayErrorString flattens a newline-joined aggregated relay error onto a
|
||||
// single line for status output.
|
||||
func relayErrorString(s string) string {
|
||||
return strings.ReplaceAll(s, "\n", "; ")
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
@@ -441,6 +449,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
} else if relay.Transport != "" {
|
||||
available = fmt.Sprintf("%s via %s", available, relay.Transport)
|
||||
}
|
||||
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
|
||||
@@ -647,3 +647,13 @@ func TestTimeAgo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapRelaysTransport(t *testing.T) {
|
||||
out := mapRelays([]*proto.RelayState{
|
||||
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
|
||||
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
|
||||
})
|
||||
require.Len(t, out.Details, 2)
|
||||
assert.Equal(t, "quic", out.Details[0].Transport)
|
||||
assert.Equal(t, "ws", out.Details[1].Transport)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -30,6 +31,7 @@ const (
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
dialWebSocketTimeout = 30 * time.Second
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
@@ -677,6 +679,7 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["dialWebSocket"] = createDialWebSocketMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
@@ -691,6 +694,74 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
func createDialWebSocketMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
url, protocols, timeout, errVal := parseDialWebSocketArgs(args)
|
||||
if !errVal.IsUndefined() {
|
||||
return errVal
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := nbwebsocket.Dial(ctx, client, url, protocols)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func parseDialWebSocketArgs(args []js.Value) (url string, protocols []string, timeout time.Duration, errVal js.Value) {
|
||||
if len(args) < 1 || args[0].Type() != js.TypeString {
|
||||
return "", nil, 0, js.ValueOf("error: dialWebSocket requires a URL string argument")
|
||||
}
|
||||
url = args[0].String()
|
||||
|
||||
if len(args) >= 2 && !args[1].IsNull() && !args[1].IsUndefined() {
|
||||
arr, err := jsStringArray(args[1])
|
||||
if err != nil {
|
||||
return "", nil, 0, js.ValueOf(fmt.Sprintf("error: protocols: %v", err))
|
||||
}
|
||||
protocols = arr
|
||||
}
|
||||
|
||||
timeout = dialWebSocketTimeout
|
||||
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
|
||||
if args[2].Type() != js.TypeNumber {
|
||||
return "", nil, 0, js.ValueOf("error: timeoutMs must be a number")
|
||||
}
|
||||
timeoutMs := args[2].Int()
|
||||
if timeoutMs <= 0 {
|
||||
return "", nil, 0, js.ValueOf("error: timeout must be positive")
|
||||
}
|
||||
timeout = time.Duration(timeoutMs) * time.Millisecond
|
||||
}
|
||||
|
||||
return url, protocols, timeout, js.Undefined()
|
||||
}
|
||||
|
||||
// jsStringArray converts a JS array of strings to a Go []string.
|
||||
func jsStringArray(v js.Value) ([]string, error) {
|
||||
if !v.InstanceOf(js.Global().Get("Array")) {
|
||||
return nil, fmt.Errorf("expected array")
|
||||
}
|
||||
n := v.Length()
|
||||
out := make([]string, n)
|
||||
for i := 0; i < n; i++ {
|
||||
el := v.Index(i)
|
||||
if el.Type() != js.TypeString {
|
||||
return nil, fmt.Errorf("element %d is not a string", i)
|
||||
}
|
||||
out[i] = el.String()
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
|
||||
304
client/wasm/internal/websocket/websocket.go
Normal file
304
client/wasm/internal/websocket/websocket.go
Normal file
@@ -0,0 +1,304 @@
|
||||
//go:build js
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type closeError struct {
|
||||
code uint16
|
||||
reason string
|
||||
}
|
||||
|
||||
func (e *closeError) Error() string {
|
||||
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
|
||||
}
|
||||
|
||||
// bufferedConn fronts a net.Conn with a reader that serves any bytes buffered
|
||||
// during the WebSocket handshake before falling through to the raw conn.
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) }
|
||||
|
||||
// Conn wraps a WebSocket connection over a NetBird TCP connection.
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
mu sync.Mutex
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
|
||||
// Optional protocols are sent via the Sec-WebSocket-Protocol header.
|
||||
func Dial(ctx context.Context, client *netbird.Client, rawURL string, protocols []string) (*Conn, error) {
|
||||
d := ws.Dialer{
|
||||
NetDial: client.Dial,
|
||||
Protocols: protocols,
|
||||
}
|
||||
|
||||
conn, br, _, err := d.Dial(ctx, rawURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("websocket dial: %w", err)
|
||||
}
|
||||
|
||||
// br is non-nil when the server pushed frames alongside the handshake
|
||||
// response; those bytes live in the bufio.Reader and must be drained
|
||||
// before reading from conn, otherwise we'd skip the first frames.
|
||||
if br != nil {
|
||||
if br.Buffered() > 0 {
|
||||
conn = &bufferedConn{Conn: conn, r: io.MultiReader(br, conn)}
|
||||
} else {
|
||||
ws.PutReader(br)
|
||||
}
|
||||
}
|
||||
|
||||
return &Conn{
|
||||
conn: conn,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadMessage reads the next WebSocket message, handling control frames automatically.
|
||||
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
|
||||
for {
|
||||
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
if msg.OpCode.IsControl() {
|
||||
if err := c.handleControl(msg); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return msg.OpCode, msg.Payload, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) handleControl(msg wsutil.Message) error {
|
||||
switch msg.OpCode {
|
||||
case ws.OpPing:
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
|
||||
case ws.OpClose:
|
||||
code, reason := parseClosePayload(msg.Payload)
|
||||
return &closeError{code: code, reason: reason}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteText sends a text WebSocket message.
|
||||
func (c *Conn) WriteText(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
|
||||
}
|
||||
|
||||
// WriteBinary sends a binary WebSocket message.
|
||||
func (c *Conn) WriteBinary(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
|
||||
}
|
||||
|
||||
// Close sends a close frame with StatusNormalClosure and closes the underlying connection.
|
||||
func (c *Conn) Close() error {
|
||||
return c.closeWith(ws.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
// closeWith sends a close frame with the given code/reason and closes the underlying connection.
|
||||
// Used to echo the server's code when responding to a server-initiated close per RFC 6455 §5.5.1.
|
||||
func (c *Conn) closeWith(code ws.StatusCode, reason string) error {
|
||||
var first bool
|
||||
c.closeOnce.Do(func() {
|
||||
first = true
|
||||
close(c.closed)
|
||||
|
||||
c.mu.Lock()
|
||||
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose, ws.NewCloseFrameBody(code, reason))
|
||||
c.mu.Unlock()
|
||||
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
|
||||
if !first {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
|
||||
// It exposes: send(string|Uint8Array), close(), and callback properties
|
||||
// onmessage, onclose, onerror.
|
||||
//
|
||||
// Callback properties may be set from the JS thread while the read loop
|
||||
// goroutine reads them. In WASM this is safe because Go and JS share a
|
||||
// single thread, but the design would need synchronization on
|
||||
// multi-threaded runtimes.
|
||||
func NewJSInterface(conn *Conn) js.Value {
|
||||
obj := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
sendFunc := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
log.Errorf("websocket send requires a data argument")
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
switch data.Type() {
|
||||
case js.TypeString:
|
||||
if err := conn.WriteText([]byte(data.String())); err != nil {
|
||||
log.Errorf("failed to send websocket text: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
default:
|
||||
buf, err := jsToBytes(data)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert js value to bytes: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
if err := conn.WriteBinary(buf); err != nil {
|
||||
log.Errorf("failed to send websocket binary: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
}
|
||||
return js.ValueOf(true)
|
||||
})
|
||||
obj.Set("send", sendFunc)
|
||||
|
||||
closeFunc := js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close websocket: %v", err)
|
||||
}
|
||||
return js.Undefined()
|
||||
})
|
||||
obj.Set("close", closeFunc)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("close websocket on readLoop exit: %v", err)
|
||||
}
|
||||
}()
|
||||
readLoop(conn, obj)
|
||||
// Undefining before Release turns post-close JS calls into TypeError
|
||||
// instead of a silent "call to released function".
|
||||
obj.Set("send", js.Undefined())
|
||||
obj.Set("close", js.Undefined())
|
||||
sendFunc.Release()
|
||||
closeFunc.Release()
|
||||
}()
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
func jsToBytes(data js.Value) ([]byte, error) {
|
||||
var uint8Array js.Value
|
||||
switch {
|
||||
case data.InstanceOf(js.Global().Get("Uint8Array")):
|
||||
uint8Array = data
|
||||
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
|
||||
uint8Array = js.Global().Get("Uint8Array").New(data)
|
||||
default:
|
||||
return nil, fmt.Errorf("send: unsupported data type, use string, Uint8Array, or ArrayBuffer")
|
||||
}
|
||||
|
||||
buf := make([]byte, uint8Array.Get("length").Int())
|
||||
js.CopyBytesToGo(buf, uint8Array)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func readLoop(conn *Conn, obj js.Value) {
|
||||
var ce *closeError
|
||||
defer func() { invokeOnClose(obj, ce) }()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-conn.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
op, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
ce = handleReadError(conn, obj, err)
|
||||
return
|
||||
}
|
||||
|
||||
dispatchMessage(obj, op, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func handleReadError(conn *Conn, obj js.Value, err error) *closeError {
|
||||
var ce *closeError
|
||||
if errors.As(err, &ce) {
|
||||
if cerr := conn.closeWith(ws.StatusCode(ce.code), ce.reason); cerr != nil {
|
||||
log.Debugf("failed to close websocket after server close frame: %v", cerr)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if onerror := obj.Get("onerror"); onerror.Truthy() {
|
||||
onerror.Invoke(js.ValueOf(err.Error()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func invokeOnClose(obj js.Value, ce *closeError) {
|
||||
onclose := obj.Get("onclose")
|
||||
if !onclose.Truthy() {
|
||||
return
|
||||
}
|
||||
if ce != nil {
|
||||
onclose.Invoke(js.ValueOf(int(ce.code)), js.ValueOf(ce.reason))
|
||||
return
|
||||
}
|
||||
onclose.Invoke()
|
||||
}
|
||||
|
||||
func dispatchMessage(obj js.Value, op ws.OpCode, payload []byte) {
|
||||
onmessage := obj.Get("onmessage")
|
||||
if !onmessage.Truthy() {
|
||||
return
|
||||
}
|
||||
switch op {
|
||||
case ws.OpText:
|
||||
onmessage.Invoke(js.ValueOf(string(payload)))
|
||||
case ws.OpBinary:
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
|
||||
js.CopyBytesToJS(uint8Array, payload)
|
||||
onmessage.Invoke(uint8Array)
|
||||
}
|
||||
}
|
||||
|
||||
func parseClosePayload(payload []byte) (uint16, string) {
|
||||
if len(payload) < 2 {
|
||||
return 1005, "" // RFC 6455: No Status Rcvd
|
||||
}
|
||||
code := binary.BigEndian.Uint16(payload[:2])
|
||||
return code, string(payload[2:])
|
||||
}
|
||||
@@ -2,4 +2,5 @@ FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||
COPY netbird-server /go/bin/netbird-server
|
||||
ARG TARGETPLATFORM
|
||||
COPY ${TARGETPLATFORM}/netbird-server /go/bin/netbird-server
|
||||
|
||||
3
go.mod
3
go.mod
@@ -56,6 +56,7 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
github.com/gobwas/ws v1.4.0
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
@@ -215,6 +216,8 @@ require (
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/go-webauthn/webauthn v0.16.4 // indirect
|
||||
github.com/go-webauthn/x v0.2.3 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
|
||||
7
go.sum
7
go.sum
@@ -249,6 +249,12 @@ github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4C
|
||||
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
|
||||
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
|
||||
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
|
||||
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
|
||||
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
|
||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
|
||||
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
@@ -845,6 +851,7 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -2,4 +2,5 @@ FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
||||
CMD ["--log-file", "console"]
|
||||
COPY netbird-mgmt /go/bin/netbird-mgmt
|
||||
ARG TARGETPLATFORM
|
||||
COPY ${TARGETPLATFORM}/netbird-mgmt /go/bin/netbird-mgmt
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
|
||||
CMD ["--log-file", "console"]
|
||||
COPY netbird-mgmt /go/bin/netbird-mgmt
|
||||
@@ -45,7 +45,7 @@ type Controller struct {
|
||||
EphemeralPeersManager ephemeral.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
affectedPeerUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
@@ -64,6 +64,13 @@ type bufferUpdate struct {
|
||||
update atomic.Bool
|
||||
}
|
||||
|
||||
type bufferAffectedUpdate struct {
|
||||
sendMu sync.Mutex
|
||||
dataMu sync.Mutex
|
||||
next *time.Timer
|
||||
peerIDs map[string]struct{}
|
||||
}
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||
@@ -201,7 +208,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
@@ -226,44 +233,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
@@ -273,6 +242,143 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
||||
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
|
||||
|
||||
if !c.hasConnectedPeers(peerIDs) {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
|
||||
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
|
||||
if len(peersToUpdate) == 0 {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validate peers: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return fmt.Errorf("failed to get proxy network maps: %v", err)
|
||||
}
|
||||
|
||||
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get flow enabled status: %v", err)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
for _, peer := range peersToUpdate {
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
|
||||
for _, id := range peerIDs {
|
||||
if c.peersUpdateManager.HasChannel(id) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
|
||||
affected := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
affected[id] = struct{}{}
|
||||
}
|
||||
|
||||
var result []*nbpeer.Peer
|
||||
for _, peer := range account.Peers {
|
||||
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||
if !c.peersUpdateManager.HasChannel(peerId) {
|
||||
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
@@ -381,66 +487,164 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
|
||||
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
})
|
||||
b := bufUpd.(*bufferAffectedUpdate)
|
||||
|
||||
b.addPeerIDs(peerIDs)
|
||||
|
||||
if !b.sendMu.TryLock() {
|
||||
// Another goroutine is already sending; it will pick up our IDs on its next drain.
|
||||
return nil
|
||||
}
|
||||
|
||||
b.stopTimer()
|
||||
|
||||
// The send and the debounced timer outlive the calling request, so detach from
|
||||
// its context to avoid sending with a cancelled context once the handler returns.
|
||||
bgCtx := context.WithoutCancel(ctx)
|
||||
|
||||
collected := b.drainPeerIDs()
|
||||
go func() {
|
||||
defer b.sendMu.Unlock()
|
||||
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, collected)
|
||||
|
||||
// Check if more peer IDs accumulated while we were sending.
|
||||
if !b.hasPending() {
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule a debounced flush for the newly accumulated IDs.
|
||||
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
ids := b.drainPeerIDs()
|
||||
if len(ids) > 0 {
|
||||
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, ids)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
|
||||
b.dataMu.Lock()
|
||||
for _, id := range ids {
|
||||
b.peerIDs[id] = struct{}{}
|
||||
}
|
||||
b.dataMu.Unlock()
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if len(b.peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0, len(b.peerIDs))
|
||||
for id := range b.peerIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
b.peerIDs = make(map[string]struct{})
|
||||
return ids
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) hasPending() bool {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
return len(b.peerIDs) > 0
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) stopTimer() {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(d, f)
|
||||
return
|
||||
}
|
||||
b.next.Reset(d)
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
return emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peerID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
return networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
@@ -578,21 +782,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -625,7 +832,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
|
||||
@@ -19,17 +19,19 @@ const (
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
|
||||
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
CountStreams() int
|
||||
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||
|
||||
@@ -57,6 +57,20 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
|
||||
// CountStreams mocks base method.
|
||||
func (m *MockController) CountStreams() int {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -113,21 +127,20 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(int64)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, peerID)
|
||||
ret0, _ := ret[0].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].([]*posture.Checks)
|
||||
ret2, _ := ret[2].(int64)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerConnected mocks base method.
|
||||
@@ -158,45 +171,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
|
||||
}
|
||||
|
||||
// OnPeersAdded mocks base method.
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersAdded indicates an expected call of OnPeersAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// OnPeersDeleted mocks base method.
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// OnPeersUpdated mocks base method.
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// StartWarmup mocks base method.
|
||||
@@ -250,3 +263,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||
}
|
||||
|
||||
@@ -918,6 +918,10 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
|
||||
}
|
||||
|
||||
for _, svc := range services {
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, svc.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
@@ -1270,6 +1274,10 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
@@ -1319,6 +1327,10 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
|
||||
@@ -458,6 +458,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
txMock.EXPECT().
|
||||
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
|
||||
Return(newEphemeralService(), nil)
|
||||
txMock.EXPECT().
|
||||
DeleteServiceTargets(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
txMock.EXPECT().
|
||||
DeleteService(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
@@ -560,6 +563,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
txMock.EXPECT().
|
||||
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
|
||||
Return(newEphemeralService(), nil)
|
||||
txMock.EXPECT().
|
||||
DeleteServiceTargets(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
txMock.EXPECT().
|
||||
DeleteService(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
@@ -604,6 +610,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
txMock.EXPECT().
|
||||
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
|
||||
Return(newEphemeralService(), nil)
|
||||
txMock.EXPECT().
|
||||
DeleteServiceTargets(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
txMock.EXPECT().
|
||||
DeleteService(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
@@ -1192,6 +1201,67 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
|
||||
}
|
||||
|
||||
func TestDeleteExpiredPeerService_DeletesTargets(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
|
||||
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before reaping")
|
||||
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
err = mgr.deleteExpiredPeerService(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "expired peer-exposed service should be deleted")
|
||||
s, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.NotFound, s.Type())
|
||||
|
||||
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, targets, 0, "orphaned target rows must be deleted when an expired peer-exposed service is reaped")
|
||||
}
|
||||
|
||||
func TestDeleteServiceFromPeer_DeletesTargets(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
|
||||
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before stopping")
|
||||
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "stopped peer-exposed service should be deleted")
|
||||
s, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.NotFound, s.Type())
|
||||
|
||||
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, targets, 0, "orphaned target rows must be deleted when a peer stops its exposed service")
|
||||
}
|
||||
|
||||
func TestValidateProtocolChange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -778,7 +778,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
peer, network, postureChecks, enableSSH, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: peerMeta,
|
||||
@@ -792,7 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, network, postureChecks, enableSSH)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
@@ -895,7 +895,7 @@ func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMess
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, network *types.Network, postureChecks []*posture.Checks, enableSSH bool) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
|
||||
@@ -914,7 +914,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
|
||||
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
|
||||
@@ -1894,7 +1894,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -2577,7 +2577,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -2668,7 +2670,9 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return fmt.Errorf("notify network map controller: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -61,7 +62,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -69,7 +70,7 @@ type Manager interface {
|
||||
UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -108,8 +109,8 @@ type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
@@ -128,6 +129,7 @@ type Manager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
activity "github.com/netbirdio/netbird/management/server/activity"
|
||||
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
idp "github.com/netbirdio/netbird/management/server/idp"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
posture "github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -79,14 +80,15 @@ func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *go
|
||||
}
|
||||
|
||||
// AddPeer mocks base method.
|
||||
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(*types.Network)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
ret3, _ := ret[3].(bool)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// AddPeer indicates an expected call of AddPeer.
|
||||
@@ -1288,14 +1290,15 @@ func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock
|
||||
}
|
||||
|
||||
// LoginPeer mocks base method.
|
||||
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoginPeer", ctx, login)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(*types.Network)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
ret3, _ := ret[3].(bool)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// LoginPeer indicates an expected call of LoginPeer.
|
||||
@@ -1320,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
@@ -1637,6 +1640,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// ExpandAndUpdateAffected mocks base method.
|
||||
func (m *MockManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ExpandAndUpdateAffected", ctx, accountID, snap, change)
|
||||
}
|
||||
|
||||
// ExpandAndUpdateAffected indicates an expected call of ExpandAndUpdateAffected.
|
||||
func (mr *MockManagerMockRecorder) ExpandAndUpdateAffected(ctx, accountID, snap, change interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandAndUpdateAffected", reflect.TypeOf((*MockManager)(nil).ExpandAndUpdateAffected), ctx, accountID, snap, change)
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks base method.
|
||||
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -84,7 +84,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
|
||||
setupKey = key.Key
|
||||
}
|
||||
|
||||
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
|
||||
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
|
||||
if err != nil {
|
||||
t.Error("expected to add new peer successfully after creating new account, but failed", err)
|
||||
}
|
||||
@@ -1092,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
@@ -1156,7 +1156,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
@@ -1504,7 +1504,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
|
||||
peerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
|
||||
}, false)
|
||||
@@ -1826,7 +1826,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1882,7 +1882,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1927,7 +1927,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
}, false)
|
||||
@@ -1935,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1956,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1980,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1990,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2017,7 +2017,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
|
||||
}, false)
|
||||
@@ -2052,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2080,7 +2080,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -2093,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3276,7 +3276,7 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
Status: &nbpeer.PeerStatus{
|
||||
@@ -3305,6 +3305,19 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3431,7 +3444,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
@@ -3500,7 +3513,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
@@ -3895,13 +3908,13 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer1")
|
||||
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
|
||||
117
management/server/affected_peers_coverage_test.go
Normal file
117
management/server/affected_peers_coverage_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestAffectedPeers_DependencyCoverageMatrix enumerates each network-map
|
||||
// dependency crossed with the change-type that can alter it, asserting the
|
||||
// resolver folds in exactly the peers whose map changes. A new dependency that
|
||||
// the resolver fails to walk should fail one of these rows; a new change-type
|
||||
// without a row is a coverage gap to add here.
|
||||
func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
type row struct {
|
||||
name string
|
||||
build func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string)
|
||||
}
|
||||
|
||||
rows := []row{
|
||||
{
|
||||
name: "policy-groups/source-group-change refreshes source+routing, excludes unrelated",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-routing-bridge/router-peer-change refreshes policy sources",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-change/explicit-policy refreshes source+routing",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-destinationresource/explicit-policy bridges to routing peer",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-change refreshes source+routing on its network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{Resources: []*resourceTypes.NetworkResource{
|
||||
{ID: s.resourceID, NetworkID: s.networkID, GroupIDs: []string{s.resourceGroupID}},
|
||||
}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "network-change refreshes source+routing on that network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "posture-check-change refreshes source+routing of gated policy",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "cov-min-version",
|
||||
Checks: posture.ChecksDefinition{NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{PostureCheckIDs: []string{check.ID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
t.Run(r.name, func(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
143
management/server/affected_peers_oldstate_test.go
Normal file
143
management/server/affected_peers_oldstate_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// An update spans an old and a new state. The affected set must be the UNION of
|
||||
// peers reachable before and after the change; resolving only against the final
|
||||
// state drops peers that were reachable but no longer are. These tests pin the
|
||||
// two paths where the old state is reachable only by the changed object's
|
||||
// previous references: detaching a resource group, and re-pointing a router peer.
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources:
|
||||
// a resource is reachable by a source group via two destination resource groups;
|
||||
// detaching one of them must still refresh that group's policy source peers, even
|
||||
// though the post-update resource no longer maps to it.
|
||||
func TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A second resource group + a second source group/peer that reaches the
|
||||
// resource only through that second group.
|
||||
const detachGroupID = "rs-detach-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: detachGroupID, Name: "rs-detach"}))
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-detach-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
// Attach the resource to the detach group as well: now in [resourceGroup, detachGroup].
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID, detachGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Policy granting the second source group access via the detach group.
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(secondSourceGroupID, detachGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
secondSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, secondSourcePeer.ID) })
|
||||
settleAffectedUpdates(secondSrcCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Detaching the resource from detachGroup removes the second source's
|
||||
// access; that source peer must be refreshed even though the post-update
|
||||
// resource no longer maps to detachGroup.
|
||||
peerShouldReceiveUpdate(t, secondSrcCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID}, // detached detachGroup
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: detaching a resource group did not refresh the old group's policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer:
|
||||
// changing router.Peer within the same network must still refresh the OLD routing
|
||||
// peer, which loses its routing role.
|
||||
func TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
router := routers[0]
|
||||
oldRoutingPeer := router.Peer
|
||||
require.NotEmpty(t, oldRoutingPeer)
|
||||
|
||||
// A new peer to become the routing peer in place of the old one.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-newrouter-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
newRoutingPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
oldCh := s.updateManager.CreateChannel(ctx, oldRoutingPeer)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, oldRoutingPeer) })
|
||||
settleAffectedUpdates(oldCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// The old routing peer stops serving the resource and must be refreshed.
|
||||
peerShouldReceiveUpdate(t, oldCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = routersManager.UpdateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
ID: router.ID,
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: newRoutingPeer.ID, // repoint within the same network
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the router peer did not refresh the old routing peer")
|
||||
}
|
||||
}
|
||||
255
management/server/affected_peers_property_test.go
Normal file
255
management/server/affected_peers_property_test.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// allPeerMaps computes the serialized per-peer network map for every peer in the
|
||||
// account, mirroring the controller's compute path so the property test compares
|
||||
// against real output.
|
||||
func allPeerMaps(t *testing.T, manager *DefaultAccountManager, accountID string) map[string]string {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
validated := make(map[string]struct{}, len(account.Peers))
|
||||
for id := range account.Peers {
|
||||
validated[id] = struct{}{}
|
||||
}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
out := make(map[string]string, len(account.Peers))
|
||||
for peerID := range account.Peers {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validated, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
// Network.Serial is an account-global counter bumped on every change; it
|
||||
// is not a per-peer dependency, so normalize it out of the comparison.
|
||||
if nm.Network != nil {
|
||||
nm.Network.Serial = 0
|
||||
}
|
||||
out[peerID] = canonicalJSON(t, nm)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// canonicalJSON marshals v and returns an order-insensitive string form: every
|
||||
// JSON array is sorted by the canonical form of its elements. The network map's
|
||||
// Peers/Routes/FirewallRules/SourceRanges slices have nondeterministic order, so
|
||||
// a raw JSON compare would report spurious changes.
|
||||
func canonicalJSON(t *testing.T, v interface{}) string {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
var parsed interface{}
|
||||
require.NoError(t, json.Unmarshal(b, &parsed))
|
||||
canonicalized, err := json.Marshal(sortAny(parsed))
|
||||
require.NoError(t, err)
|
||||
return string(canonicalized)
|
||||
}
|
||||
|
||||
func sortAny(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
for i := range val {
|
||||
val[i] = sortAny(val[i])
|
||||
}
|
||||
sort.Slice(val, func(i, j int) bool {
|
||||
bi, _ := json.Marshal(val[i])
|
||||
bj, _ := json.Marshal(val[j])
|
||||
return string(bi) < string(bj)
|
||||
})
|
||||
return val
|
||||
case map[string]interface{}:
|
||||
for k := range val {
|
||||
val[k] = sortAny(val[k])
|
||||
}
|
||||
return val
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// changedPeers returns the peer IDs whose serialized map differs between before
|
||||
// and after.
|
||||
func changedPeers(before, after map[string]string) []string {
|
||||
var changed []string
|
||||
for id, b := range before {
|
||||
a, ok := after[id]
|
||||
if !ok || a != b {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
for id := range after {
|
||||
if _, ok := before[id]; !ok {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// TestAffectedPeers_Property_ResolverSupersetsRealChanges builds a topology,
|
||||
// applies random changes, and asserts that the resolver's affected set is a
|
||||
// superset of the peers whose real network map actually changed. If the resolver
|
||||
// ever misses a dependency, a change will alter a peer's map without that peer
|
||||
// appearing in the affected set, failing here.
|
||||
func TestAffectedPeers_Property_ResolverSupersetsRealChanges(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing peer->resource policy so the resource/router bridge is live.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extra peers and groups to give mutations room to move membership around.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "prop-key", types.SetupKeyReusable, 0, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
extraPeers := make([]string, 0, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
p := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
extraPeers = append(extraPeers, p.ID)
|
||||
}
|
||||
extraGroups := []string{"prop-grp-0", "prop-grp-1"}
|
||||
for _, g := range extraGroups {
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: g, Name: g}))
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
allGroups := append([]string{s.sourceGroupID, s.resourceGroupID, s.routerPeerGroupID}, extraGroups...)
|
||||
allPeers := append([]string{s.sourcePeerID, s.routerPeerID, s.routerGroupPeerID, s.unrelatedPeerID}, extraPeers...)
|
||||
|
||||
for iter := 0; iter < 60; iter++ {
|
||||
change, apply := s.randomMutation(t, rng, allGroups, allPeers)
|
||||
if apply == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
before := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
resolvedSet := make(map[string]struct{})
|
||||
resolve := func() {
|
||||
require.NoError(t, s.manager.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||
snap, err := affectedpeers.Load(ctx, tx, s.accountID, change)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, id := range snap.Expand(ctx, s.accountID, change) {
|
||||
resolvedSet[id] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// Resolve on both sides of the mutation and union: removals are visible
|
||||
// only pre-apply (the leaving peer is still a member), additions only
|
||||
// post-apply (the joining peer is now a member). Production captures both
|
||||
// via per-path handling (e.g. UpdateGroup passes peersToRemove); the union
|
||||
// models that without coupling the test to each path's ordering.
|
||||
resolve()
|
||||
changedIDs := change.ChangedPeerIDs
|
||||
apply()
|
||||
resolve()
|
||||
|
||||
after := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
// The explicitly-changed peer's own map refresh is the caller's
|
||||
// responsibility (the resolver returns the peers to propagate to), so it
|
||||
// is allowed to be absent from the resolved set.
|
||||
changedExplicitly := make(map[string]struct{}, len(changedIDs))
|
||||
for _, id := range changedIDs {
|
||||
changedExplicitly[id] = struct{}{}
|
||||
}
|
||||
|
||||
for _, id := range changedPeers(before, after) {
|
||||
if _, stillExists := after[id]; !stillExists {
|
||||
continue
|
||||
}
|
||||
if _, isExplicit := changedExplicitly[id]; isExplicit {
|
||||
continue
|
||||
}
|
||||
_, ok := resolvedSet[id]
|
||||
require.Truef(t, ok,
|
||||
"iter %d: peer %s network map changed but was not in the resolver's affected set %v (change=%+v)",
|
||||
iter, id, maps.Keys(resolvedSet), change)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// randomMutation picks a random change, returns the Change to resolve and a
|
||||
// function that applies the underlying store mutation. apply is nil when the
|
||||
// drawn mutation is a no-op for the current state.
|
||||
func (s *routerScenario) randomMutation(t *testing.T, rng *rand.Rand, allGroups, allPeers []string) (affectedpeers.Change, func()) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
switch rng.Intn(3) {
|
||||
case 0:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
peerID := allPeers[rng.Intn(len(allPeers))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if slicesContains(grp.Peers, peerID) {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupAddPeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
case 1:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if len(grp.Peers) == 0 {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
peerID := grp.Peers[rng.Intn(len(grp.Peers))]
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupDeletePeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
default:
|
||||
src := allGroups[rng.Intn(len(allGroups))]
|
||||
dst := allGroups[rng.Intn(len(allGroups))]
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: fmt.Sprintf("prop-policy-%d", rng.Int()),
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{src},
|
||||
Destinations: []string{dst},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
}},
|
||||
}
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
func() {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func slicesContains(s []string, v string) bool {
|
||||
for _, x := range s {
|
||||
if x == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
164
management/server/affected_peers_querycount_test.go
Normal file
164
management/server/affected_peers_querycount_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// countingStore wraps a real store and counts the per-account collection loads
|
||||
// the resolver performs, so a test can assert each is read at most once and that
|
||||
// irrelevant collections are skipped entirely.
|
||||
type countingStore struct {
|
||||
store.Store
|
||||
mu sync.Mutex
|
||||
counts map[string]int
|
||||
}
|
||||
|
||||
func newCountingStore(s store.Store) *countingStore {
|
||||
return &countingStore{Store: s, counts: map[string]int{}}
|
||||
}
|
||||
|
||||
func (c *countingStore) bump(name string) {
|
||||
c.mu.Lock()
|
||||
c.counts[name]++
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *countingStore) count(name string) int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.counts[name]
|
||||
}
|
||||
|
||||
func (c *countingStore) total() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
n := 0
|
||||
for _, v := range c.counts {
|
||||
n += v
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountPolicies(ctx context.Context, ls store.LockingStrength, accountID string) ([]*types.Policy, error) {
|
||||
c.bump("policies")
|
||||
return c.Store.GetAccountPolicies(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountRoutes(ctx context.Context, ls store.LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
c.bump("routes")
|
||||
return c.Store.GetAccountRoutes(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountNameServerGroups(ctx context.Context, ls store.LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
c.bump("nameservers")
|
||||
return c.Store.GetAccountNameServerGroups(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountDNSSettings(ctx context.Context, ls store.LockingStrength, accountID string) (*types.DNSSettings, error) {
|
||||
c.bump("dnssettings")
|
||||
return c.Store.GetAccountDNSSettings(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkRoutersByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
c.bump("routers")
|
||||
return c.Store.GetNetworkRoutersByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkResourcesByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
c.bump("resources")
|
||||
return c.Store.GetNetworkResourcesByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountServices(ctx context.Context, ls store.LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
c.bump("services")
|
||||
return c.Store.GetAccountServices(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NoRedundantFullTableLoads asserts the resolver
|
||||
// loads each per-account collection at most once per Resolve (memoization) even
|
||||
// on a change that drives every bridge, and skips the services table when the
|
||||
// account has no embedded proxy peers.
|
||||
func TestAffectedPeers_QueryCount_NoRedundantFullTableLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A group change that exercises policies, routers, resources and the bridge.
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
|
||||
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
|
||||
require.NoError(t, err)
|
||||
affected := snap.Expand(ctx, s.accountID, change)
|
||||
assert.Contains(t, affected, s.routerPeerID, "bridge must still resolve the routing peer")
|
||||
|
||||
for _, name := range []string{"policies", "routes", "nameservers", "dnssettings", "routers", "resources"} {
|
||||
assert.LessOrEqualf(t, cs.count(name), 1,
|
||||
"%s must be loaded at most once per Resolve, got %d", name, cs.count(name))
|
||||
}
|
||||
assert.Equal(t, 0, cs.count("services"),
|
||||
"services must not be loaded when the account has no embedded proxy peers")
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads asserts that a change with
|
||||
// no group/peer signal touches no per-account collections beyond what its inputs
|
||||
// require.
|
||||
func TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A bare network change drives only the router->source bridge: routers and
|
||||
// resources are needed, but routes/nameservers/dnssettings/services are not.
|
||||
_, err := affectedpeers.Load(ctx, cs, s.accountID, affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, cs.count("routes"), "routes must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("nameservers"), "nameservers must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("dnssettings"), "dnssettings must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("services"), "services must not be loaded for a network-only change")
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_ExpandReadsNothing is the core invariant of the
|
||||
// Load/Expand split: Load (run inside the transaction) does all store reads;
|
||||
// Expand (run after commit) must touch the store ZERO times, so it never holds
|
||||
// the write lock and never reads post-commit state.
|
||||
func TestAffectedPeers_QueryCount_ExpandReadsNothing(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, cs.total(), 0, "Load must read the store")
|
||||
|
||||
// Any store access during Expand would increment the same counter. Expand
|
||||
// operates purely on the snapshot, so the count must not move.
|
||||
readsAfterLoad := cs.total()
|
||||
affected := snap.Expand(ctx, s.accountID, change)
|
||||
assert.Contains(t, affected, s.routerPeerID, "Expand must still produce the affected peers from the snapshot")
|
||||
assert.Equal(t, readsAfterLoad, cs.total(), "Expand must perform zero store reads — it operates purely on the loaded snapshot")
|
||||
}
|
||||
333
management/server/affected_peers_router_paths_test.go
Normal file
333
management/server/affected_peers_router_paths_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (s *routerScenario) resolveGroupChangeAffected(ctx context.Context, changedGroupIDs []string) []string {
|
||||
change := affectedpeers.Change{ChangedGroupIDs: changedGroupIDs}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func (s *routerScenario) resolvePeerChangeAffected(ctx context.Context, changedPeerIDs []string) []string {
|
||||
change := affectedpeers.Change{ChangedPeerIDs: changedPeerIDs}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source group member must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"changing the source group of a peer->resource policy must refresh the resource's routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"changing the source group must refresh the routing peer defined via router.PeerGroups")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterPeerGroupMembership_RefreshesPolicySources(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.routerPeerGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID, "the routing peer itself must be affected")
|
||||
assert.Contains(t, affected, s.sourcePeerID,
|
||||
"changing the router's PeerGroups must refresh the source peers of policies serving the resource")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"a status change on a source peer must refresh the resource's routing peer that serves it")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_ByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"DestinationResource-targeted policy must still bridge a source-peer change to the routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeleteGroup_ResolvesAffectedPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const memberOnlyGroupID = "rs-memberonly-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: memberOnlyGroupID, Name: "rs-memberonly", Peers: []string{s.sourcePeerID},
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{memberOnlyGroupID})
|
||||
assert.Empty(t, affected, "an unlinked group has no network-map impact, so no peer is affected")
|
||||
|
||||
require.NoError(t, s.manager.DeleteGroup(ctx, s.accountID, userID, memberOnlyGroupID))
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupAddResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const extraResourceGroupID = "rs-resource-grp-extra"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: extraResourceGroupID, Name: "rs-resource-extra",
|
||||
}))
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, extraResourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.manager.GroupAddResource(ctx, s.accountID, extraResourceGroupID, types.Resource{
|
||||
ID: s.resourceID,
|
||||
Type: types.ResourceTypeHost,
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{extraResourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"attaching a resource to a policy destination group must refresh the resource's routing peer")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "policy source peers must refresh")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func (s *routerScenario) createPostureCheckGatedPolicy(t *testing.T, ctx context.Context) string {
|
||||
t.Helper()
|
||||
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
return check.ID
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_SavePostureCheck_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
checkID := s.createPostureCheckGatedPolicy(t, ctx)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
ID: checkID,
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.31.0"},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: editing a posture check did not refresh source + routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSourcePeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: updating a DestinationResource-targeted resource did not refresh its policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, routersManager, _ := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-disabled", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
disabledRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: disabledRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9000,
|
||||
Enabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
|
||||
settleAffectedUpdates(disabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "groupiso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-group change for another resource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "peeriso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-peer change for another resource")
|
||||
}
|
||||
771
management/server/affected_peers_router_test.go
Normal file
771
management/server/affected_peers_router_test.go
Normal file
@@ -0,0 +1,771 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// routerScenario captures the topology from the bug report:
|
||||
//
|
||||
// network ── router (routing peer) ── resource (in resourceGroup)
|
||||
// independent peer ──(policy: source -> resource)──> resource
|
||||
//
|
||||
// The routing peer must be refreshed when a policy grants a source peer access
|
||||
// to the resource, because the network map connects the source peer to the
|
||||
// routing peer at compute time (Account.GetPoliciesForNetworkResource +
|
||||
// addNetworksRoutingPeers). The routing peer is NOT a member of the resource
|
||||
// group, so static group/peer resolution alone cannot find it.
|
||||
type routerScenario struct {
|
||||
manager *DefaultAccountManager
|
||||
updateManager *update_channel.PeersUpdateManager
|
||||
accountID string
|
||||
networkID string
|
||||
|
||||
sourcePeerID string // independent peer that the policy grants access from
|
||||
sourceGroupID string // group containing the source peer
|
||||
|
||||
routerPeerID string // peer acting as the routing peer (direct router.Peer)
|
||||
routerGroupPeerID string // peer that is a member of routerPeerGroup
|
||||
routerPeerGroupID string // group used for router.PeerGroups
|
||||
|
||||
resourceID string // network resource
|
||||
resourceGroupID string // group whose member is the resource (no peers)
|
||||
|
||||
unrelatedPeerID string // peer in no relevant entity
|
||||
}
|
||||
|
||||
// setupRouterScenario builds the topology above with the default policy removed
|
||||
// and channels NOT yet created, so callers control exactly when updates can flow.
|
||||
func setupRouterScenario(t *testing.T, directRouterPeer bool) *routerScenario {
|
||||
t.Helper()
|
||||
|
||||
manager, updateManager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := createAccount(manager, "router_scenario", userID, "")
|
||||
require.NoError(t, err)
|
||||
accountID := account.Id
|
||||
|
||||
// Remove the default policy so AddPeer/CreateGroup don't schedule unrelated updates.
|
||||
policies, err := manager.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, p := range policies {
|
||||
require.NoError(t, manager.Store.DeletePolicy(ctx, accountID, p.ID))
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(ctx, accountID, "rs-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
sourcePeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerGroupPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
unrelatedPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
|
||||
const (
|
||||
sourceGroupID = "rs-source-grp"
|
||||
routerPeerGroupID = "rs-router-grp"
|
||||
resourceGroupID = "rs-resource-grp"
|
||||
)
|
||||
|
||||
for _, g := range []*types.Group{
|
||||
{ID: sourceGroupID, Name: "rs-source", Peers: []string{sourcePeer.ID}},
|
||||
{ID: routerPeerGroupID, Name: "rs-router", Peers: []string{routerGroupPeer.ID}},
|
||||
{ID: resourceGroupID, Name: "rs-resource"}, // intentionally peerless; the resource is its only member
|
||||
} {
|
||||
require.NoError(t, manager.CreateGroup(ctx, accountID, userID, g))
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network",
|
||||
AccountID: accountID,
|
||||
Name: "rs-network",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
router := &routerTypes.NetworkRouter{
|
||||
ID: "rs-router",
|
||||
NetworkID: network.ID,
|
||||
AccountID: accountID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
if directRouterPeer {
|
||||
router.Peer = routerPeer.ID
|
||||
} else {
|
||||
router.PeerGroups = []string{routerPeerGroupID}
|
||||
}
|
||||
_, err = routersManager.CreateRouter(ctx, userID, router)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &routerScenario{
|
||||
manager: manager,
|
||||
updateManager: updateManager,
|
||||
accountID: accountID,
|
||||
networkID: network.ID,
|
||||
sourcePeerID: sourcePeer.ID,
|
||||
sourceGroupID: sourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
routerGroupPeerID: routerGroupPeer.ID,
|
||||
routerPeerGroupID: routerPeerGroupID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
unrelatedPeerID: unrelatedPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicy builds a policy granting the source group access to the
|
||||
// resource, referencing the resource by its group in the rule destination.
|
||||
func peerToResourcePolicyByGroup(sourceGroupID, resourceGroupID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-group",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
Destinations: []string{resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicyByResource builds a policy referencing the resource
|
||||
// directly via DestinationResource rather than its group.
|
||||
func peerToResourcePolicyByResource(sourceGroupID, resourceID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
DestinationResource: types.Resource{ID: resourceID, Type: types.ResourceTypeHost},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolvePolicyAffected mirrors SavePolicy's resolution: resolve the affected
|
||||
// peers for the given policy.
|
||||
func (s *routerScenario) resolvePolicyAffected(ctx context.Context, policy *types.Policy) []string {
|
||||
change := affectedpeers.Change{Policies: []*types.Policy{policy}}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_SourcePeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// BUG: the direct routing peer serves the resource's subnet to the source
|
||||
// peer, so it must be refreshed when the policy is created. The policy path
|
||||
// only resolves the literal rule groups (source group + resource group);
|
||||
// the resource group has no peer members and the router peer is reachable
|
||||
// only through the network, so it is dropped.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer (router.Peer) serving the resource must be affected by a policy granting access to it")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Router defined via PeerGroups instead of a direct peer.
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (router.PeerGroups member) serving the resource must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DestResource_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// When the resource is referenced via DestinationResource, RuleGroups()
|
||||
// returns only the source group and the resource ID is not a peer, so
|
||||
// collectPolicyAffectedGroupsAndPeers yields nothing for the destination at
|
||||
// all. The routing peer is dropped here too.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DestResource_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (PeerGroups) must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_SourceResourcePeer_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// Source expressed as a direct peer (SourceResource), destination as resource group.
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "sourceResource-peer-to-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
SourceResource: types.Resource{ID: s.sourcePeerID, Type: types.ResourceTypePeer},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// The direct source peer IS picked up (collectPolicyAffectedGroupsAndPeers
|
||||
// handles SourceResource peers), but the routing peer is still missing.
|
||||
assert.Contains(t, affected, s.sourcePeerID, "direct source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_UnrelatedPeerNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Guard against an over-broad fix: a peer in no relevant entity must never
|
||||
// be pulled in.
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_ResourceSideBridgesToRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing policy grants the source group access to the resource.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Drive an update through the resource manager and assert the routing peer
|
||||
// is among the affected set by observing the channel. This path walks
|
||||
// policies whose destinations reference the resource's groups, folds in the
|
||||
// source groups, and loads the network's routers, so it reaches both the
|
||||
// source peer and the routing peer.
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
rm := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = rm.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
// settleAffectedUpdates waits for in-flight async updates to arrive, then drains
|
||||
// every given channel so subsequent assertions start from a clean slate.
|
||||
//
|
||||
// Setup (CreateNetwork/CreateResource/CreateRouter) fires async UpdateAffectedPeers
|
||||
// goroutines; draining first means the assertion only observes updates from the
|
||||
// action under test, not setup stragglers.
|
||||
func settleAffectedUpdates(chans ...<-chan *network_map.UpdateMessage) {
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
for _, ch := range chans {
|
||||
drainPeerUpdates(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: creating peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer (PeerGroups) not refreshed on policy create")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DestResource_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer not refreshed when policy targets DestinationResource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeletePolicy_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
require.NoError(t, s.manager.DeletePolicy(ctx, s.accountID, policy.ID, userID))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: deleting peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *routerScenario) managers() (resources.Manager, routers.Manager, networks.Manager) {
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
resourcesManager := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
routersManager := routers.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
networksManager := networks.NewManager(s.manager.Store, permissionsManager, resourcesManager, routersManager, s.manager)
|
||||
return resourcesManager, routersManager, networksManager
|
||||
}
|
||||
|
||||
type secondTopology struct {
|
||||
networkID string
|
||||
resourceID string
|
||||
resourceGroupID string
|
||||
routerPeerID string
|
||||
}
|
||||
|
||||
func (s *routerScenario) addSecondTopology(t *testing.T, suffix string) secondTopology {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
resourcesManager, routersManager, networksManager := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-"+suffix, types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
routerPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
resourceGroupID := "rs-resource-grp-" + suffix
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: resourceGroupID, Name: "rs-resource-" + suffix,
|
||||
}))
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network-" + suffix,
|
||||
AccountID: s.accountID,
|
||||
Name: "rs-network-" + suffix,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: s.accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host-" + suffix,
|
||||
Address: "10.40.50.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: network.ID,
|
||||
AccountID: s.accountID,
|
||||
Peer: routerPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return secondTopology{
|
||||
networkID: network.ID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicy_BothRoutingPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "b")
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerACh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
routerBCh := s.updateManager.CreateChannel(ctx, second.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, second.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerACh, routerBCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerACh)
|
||||
peerShouldReceiveUpdate(t, routerBCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Destinations = []string{second.resourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the policy destination did not refresh both routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicy_AddSource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
newSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, secondSourcePeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(newSrcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, newSrcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Sources = []string{s.sourceGroupID, secondSourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: adding a source group did not refresh the new source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DestResource_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: DestinationResource policy with PeerGroups router did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-r2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: secondRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9998,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "first routing peer must be affected")
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
routers[0].Enabled = false
|
||||
require.NoError(t, s.manager.Store.UpdateNetworkRouter(ctx, routers[0]))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
res, err := s.manager.Store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, s.accountID, s.resourceID)
|
||||
require.NoError(t, err)
|
||||
res.Enabled = false
|
||||
require.NoError(t, s.manager.Store.SaveNetworkResource(ctx, res))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRule(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.Rules[0].Enabled = false
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled rule must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_MultiRule(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "c")
|
||||
ctx := context.Background()
|
||||
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "multi-rule-two-resources",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{second.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer for resource A must be affected")
|
||||
assert.Contains(t, affected, second.routerPeerID, "routing peer for resource B must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RouterOtherNetwork(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "d")
|
||||
ctx := context.Background()
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a policy that does not target its resource")
|
||||
}
|
||||
1802
management/server/affected_peers_test.go
Normal file
1802
management/server/affected_peers_test.go
Normal file
File diff suppressed because it is too large
Load Diff
825
management/server/affectedpeers/resolver.go
Normal file
825
management/server/affectedpeers/resolver.go
Normal file
@@ -0,0 +1,825 @@
|
||||
// Package affectedpeers computes which peers' network maps a change touches, so
|
||||
// only those peers are refreshed instead of the whole account.
|
||||
//
|
||||
// Two phases keep the dependency walk off the write transaction:
|
||||
// - Load: reads the needed collections. Call INSIDE the mutating tx (consistent,
|
||||
// and before a delete/removal severs the old state).
|
||||
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
|
||||
//
|
||||
// Enabled is never consulted: toggling it is itself an observable change.
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// Snapshot is an in-memory view of the collections needed to expand a Change.
|
||||
// Loaded in-tx, walked by Expand after commit. Only the collections the Change
|
||||
// can touch are loaded; the rest stay nil (see Load).
|
||||
type Snapshot struct {
|
||||
policies []*types.Policy
|
||||
routes []*route.Route
|
||||
nsGroups []*nbdns.NameServerGroup
|
||||
dnsSettings *types.DNSSettings
|
||||
routers []*routerTypes.NetworkRouter
|
||||
resources []*resourceTypes.NetworkResource
|
||||
services []*rpservice.Service
|
||||
proxyByCluster map[string][]string
|
||||
groups map[string]*types.Group
|
||||
groupPeers map[string]map[string]struct{} // groupID -> member peer IDs
|
||||
}
|
||||
|
||||
// Load reads the collections a Change requires, inside the caller's tx. It mirrors
|
||||
// Expand's walker preconditions, loading only what the change can touch.
|
||||
func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snapshot, error) {
|
||||
snap := &Snapshot{}
|
||||
if c.isEmpty() {
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
if err := snap.loadCollections(ctx, s, accountID, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := snap.loadGroupIndex(ctx, s, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
|
||||
// collections a Change can touch, gated to what the walk needs.
|
||||
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
|
||||
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
|
||||
// the resource<->router bridge can fire for any of these
|
||||
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
|
||||
|
||||
if needsRoutersResources {
|
||||
if err := snap.loadPolicyRoutersResources(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if hasGroupOrPeerChange {
|
||||
if err := snap.loadRoutesAndProxy(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
|
||||
if err := snap.loadDNS(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadPolicyRoutersResources loads the policies plus the routers and resources
|
||||
// the resource<->router bridge walks.
|
||||
func (snap *Snapshot) loadPolicyRoutersResources(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.policies, err = s.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if snap.routers, err = s.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
snap.resources, err = s.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadRoutesAndProxy loads the routes and the embedded-proxy services index.
|
||||
func (snap *Snapshot) loadRoutesAndProxy(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.routes, err = s.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return snap.loadProxyServices(ctx, s, accountID)
|
||||
}
|
||||
|
||||
// loadDNS loads the nameserver groups and account DNS settings.
|
||||
func (snap *Snapshot) loadDNS(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.nsGroups, err = s.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
snap.dnsSettings, err = s.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadProxyServices loads the embedded-proxy cluster index, and the services only
|
||||
// when the account actually has embedded proxy peers.
|
||||
func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.proxyByCluster, err = s.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(snap.proxyByCluster) == 0 {
|
||||
return nil
|
||||
}
|
||||
snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadGroupIndex loads all groups (for group.Resources) and builds the
|
||||
// group->member-peers index. Always needed: the bridge resolves group.Resources
|
||||
// and Expand maps groups to member peers.
|
||||
func (snap *Snapshot) loadGroupIndex(ctx context.Context, s store.Store, accountID string) error {
|
||||
groups, err := s.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
snap.groups = make(map[string]*types.Group, len(groups))
|
||||
snap.groupPeers = make(map[string]map[string]struct{}, len(groups))
|
||||
for _, g := range groups {
|
||||
snap.groups[g.ID] = g
|
||||
members := make(map[string]struct{}, len(g.Peers))
|
||||
for _, pID := range g.Peers {
|
||||
members[pID] = struct{}{}
|
||||
}
|
||||
snap.groupPeers[g.ID] = members
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change describes what changed in an account.
|
||||
type Change struct {
|
||||
ChangedGroupIDs []string
|
||||
ChangedPeerIDs []string
|
||||
Policies []*types.Policy
|
||||
Routes []*route.Route
|
||||
Routers []*routerTypes.NetworkRouter
|
||||
Resources []*resourceTypes.NetworkResource
|
||||
Networks []*networkTypes.Network
|
||||
PostureCheckIDs []string
|
||||
|
||||
// DistributionGroupIDs are groups whose members are directly affected, with no
|
||||
// dependency walk — the change distributes config to the groups' member peers
|
||||
// only (nameserver groups, DNS DisabledManagementGroups), not through the
|
||||
// policy/route reachability graph. Pass old∪new so both states refresh.
|
||||
DistributionGroupIDs []string
|
||||
|
||||
// RemovedPeersByGroup: peers that left a group, keyed by that group. They are no
|
||||
// longer in the group's member index but still lose its reachability, so they are
|
||||
// folded in — but only when the group is linked (an unlinked group has no map
|
||||
// impact), matching how current members are handled.
|
||||
RemovedPeersByGroup map[string][]string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
return len(c.ChangedGroupIDs) == 0 &&
|
||||
len(c.ChangedPeerIDs) == 0 &&
|
||||
len(c.Policies) == 0 &&
|
||||
len(c.Routes) == 0 &&
|
||||
len(c.Routers) == 0 &&
|
||||
len(c.Resources) == 0 &&
|
||||
len(c.Networks) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.DistributionGroupIDs) == 0 &&
|
||||
len(c.RemovedPeersByGroup) == 0
|
||||
}
|
||||
|
||||
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
|
||||
// no store access. Run after the producing tx commits. Logs the full walk at
|
||||
// trace level for diagnosing a miscalculation.
|
||||
func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []string {
|
||||
if c.isEmpty() {
|
||||
return nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
|
||||
// Collect returns the affected group and direct-peer IDs without expanding groups
|
||||
// to members. Test-only introspection; use Resolve otherwise.
|
||||
func Collect(ctx context.Context, s store.Store, accountID string, c Change) (groupIDs []string, directPeerIDs []string) {
|
||||
if c.isEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
snap, err := Load(ctx, s, accountID, c)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load snapshot for affected peers collect: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
}
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
return
|
||||
}
|
||||
for groupID, members := range r.snap.groupPeers {
|
||||
for pID := range r.changedPeerSet {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.changedGroupSet[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) walk() {
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromExplicitRouters(r.change.Routers)
|
||||
r.collectFromExplicitResources(r.change.Resources)
|
||||
r.collectFromExplicitNetworks(r.change.Networks)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into groupSet so expand() maps them to members, without the policy/
|
||||
// route walk that changedGroupSet would trigger.
|
||||
addAll(r.groupSet, r.change.DistributionGroupIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
r.collectFromDNSSettings()
|
||||
r.collectFromNetworkRouters()
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectResourceRouterBridge()
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
ctx context.Context
|
||||
snap *Snapshot
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
|
||||
|
||||
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
|
||||
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var ids []string
|
||||
for gID := range groupSet {
|
||||
for pID := range r.snap.groupPeers[gID] {
|
||||
if _, ok := seen[pID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[pID] = struct{}{}
|
||||
ids = append(ids, pID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (r *resolver) expand() []string {
|
||||
peerIDs := r.peerIDsForGroups(r.groupSet)
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.peerSet {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Fold in removed peers only when their group is linked (in groupSet).
|
||||
for groupID, removed := range r.change.RemovedPeersByGroup {
|
||||
if _, linked := r.groupSet[groupID]; !linked {
|
||||
continue
|
||||
}
|
||||
for _, id := range removed {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: removed peer %s from linked group %s -> affected", id, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand done: account=%s -> %d affected peers: %v", r.accountID, len(peerIDs), peerIDs)
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitRouters folds changed routers' peers and marks their networks
|
||||
// for the bridge. Passing the old router keeps a repointed router's previous peers
|
||||
// affected without a post-commit read.
|
||||
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitResources marks changed resources' networks for the bridge and
|
||||
// treats their group IDs as changed, so policies targeting the resource via a
|
||||
// now-detached (old) group still refresh.
|
||||
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
addAll(r.changedGroupSet, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
|
||||
// no groups/peers of its own.
|
||||
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
|
||||
if network.ID != "" {
|
||||
r.networkIDs[network.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
for _, ns := range r.snap.nsGroups {
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromProxyServices() {
|
||||
if len(r.snap.proxyByCluster) == 0 || len(r.snap.services) == 0 {
|
||||
return
|
||||
}
|
||||
services, proxyByCluster := r.snap.services, r.snap.proxyByCluster
|
||||
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.peerSet[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
ids := r.peerIDsForGroups(r.changedGroupSet)
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// collectResourceRouterBridge crosses between source peers and routing peers, which
|
||||
// are reachable only via resource -> network -> router, not through the policy's own
|
||||
// groups: source -> router (targeted resources' networks), then router -> source.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
for _, router := range r.networkRouters() {
|
||||
if _, ok := networkIDs[router.NetworkID]; !ok {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) resourceNetworkIDs(resourceIDs map[string]struct{}) map[string]struct{} {
|
||||
networkIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := resourceIDs[resource.ID]; ok {
|
||||
networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return networkIDs
|
||||
}
|
||||
|
||||
func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
for _, gID := range rule.Destinations {
|
||||
destGroupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(destGroupSet) == 0 {
|
||||
return false
|
||||
}
|
||||
for gID := range destGroupSet {
|
||||
group := r.snap.groups[gID]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
for _, res := range group.Resources {
|
||||
if isInSet(res.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolver) policyDestinationResourceIDs(policies ...*types.Policy) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
destGroupSet := collectPolicyDestinations(resourceIDs, policies...)
|
||||
r.addGroupResourceIDs(destGroupSet, resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
// collectPolicyDestinations adds direct destination resource IDs to resourceIDs and
|
||||
// returns the referenced destination group IDs.
|
||||
func collectPolicyDestinations(resourceIDs map[string]struct{}, policies ...*types.Policy) map[string]struct{} {
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, policy := range policies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(destGroupSet, rule.Destinations)
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return destGroupSet
|
||||
}
|
||||
|
||||
// addGroupResourceIDs folds the resource IDs of the given groups into resourceIDs.
|
||||
func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs map[string]struct{}) {
|
||||
for gID := range groupIDs {
|
||||
group := r.snap.groups[gID]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
for _, res := range group.Resources {
|
||||
if res.ID != "" {
|
||||
resourceIDs[res.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
|
||||
if res.Type != types.ResourceTypePeer || res.ID == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := set[res.ID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
|
||||
for _, pid := range proxyPeers {
|
||||
if _, ok := changedPeers[pid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isInSet(id string, set map[string]struct{}) bool {
|
||||
_, ok := set[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func addAll(set map[string]struct{}, slices ...[]string) {
|
||||
for _, s := range slices {
|
||||
for _, id := range s {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toSet(ids []string) map[string]struct{} {
|
||||
set := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func setToSlice(set map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(set))
|
||||
for id := range set {
|
||||
s = append(s, id)
|
||||
}
|
||||
return s
|
||||
}
|
||||
140
management/server/affectedpeers/resolver_test.go
Normal file
140
management/server/affectedpeers/resolver_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
}
|
||||
return groups, peers
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_Basic(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3"}, groups)
|
||||
assert.Empty(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_WithPeerResources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "p1", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{"g2"},
|
||||
DestinationResource: types.Resource{ID: "p2", Type: types.ResourceTypePeer},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.ElementsMatch(t, []string{"p1", "p2"}, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NilPolicy(t *testing.T) {
|
||||
groups, peers := policyGroupsAndPeers(nil)
|
||||
assert.Nil(t, groups)
|
||||
assert.Nil(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_MultiplePolicies(t *testing.T) {
|
||||
old := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1"}, Destinations: []string{"g2"}}}}
|
||||
updated := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g3"}, Destinations: []string{"g4"}}}}
|
||||
groups, _ := policyGroupsAndPeers(updated, old)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3", "g4"}, groups)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NonPeerResource(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "domain-1", Type: types.ResourceTypeDomain},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.Empty(t, peers, "domain resource type should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.True(t, Change{}.isEmpty())
|
||||
assert.False(t, Change{ChangedGroupIDs: []string{"g"}}.isEmpty())
|
||||
assert.False(t, Change{ChangedPeerIDs: []string{"p"}}.isEmpty())
|
||||
assert.False(t, Change{Policies: []*types.Policy{{}}}.isEmpty())
|
||||
assert.False(t, Change{Resources: []*resourceTypes.NetworkResource{{ID: "r"}}}.isEmpty())
|
||||
assert.False(t, Change{Networks: []*networkTypes.Network{{ID: "n"}}}.isEmpty())
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
assert.True(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc1": {}}))
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
|
||||
groupSet := map[string]struct{}{}
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicySources(policy, groupSet, peerSet)
|
||||
|
||||
assert.Contains(t, groupSet, "g1")
|
||||
assert.NotContains(t, groupSet, "g2", "destination groups must not be collected as sources")
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -47,8 +48,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
@@ -63,11 +65,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
@@ -75,6 +72,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
change = affectedpeers.Change{DistributionGroupIDs: slices.Concat(addedGroups, removedGroups)}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -85,9 +87,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -133,20 +133,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
|
||||
@@ -298,11 +298,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
|
||||
savedPeer1, _, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -79,7 +80,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -91,11 +93,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -106,6 +103,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -116,9 +118,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -134,7 +134,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -153,20 +154,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
|
||||
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
if err = syncGroupMembership(ctx, transaction, accountID, newGroup.ID, peersToAdd, peersToRemove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -178,6 +166,17 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
// A membership change does not alter which entities reference the group, so
|
||||
// the dependency walk runs once against the post-change snapshot. The new
|
||||
// members are already in the snapshot's index; the removed members are
|
||||
// carried separately and folded in only when the group is linked.
|
||||
if len(peersToRemove) > 0 {
|
||||
change.RemovedPeersByGroup = map[string][]string{newGroup.ID: peersToRemove}
|
||||
}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -188,13 +187,26 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncGroupMembership applies the peer membership delta for a group within a transaction.
|
||||
func syncGroupMembership(ctx context.Context, transaction store.Store, accountID, groupID string, peersToAdd, peersToRemove []string) error {
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
@@ -209,11 +221,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snaps []*affectedpeers.Snapshot
|
||||
var changes []affectedpeers.Change
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
createdCount := 0
|
||||
for _, newGroup := range groups {
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
var snap *affectedpeers.Snapshot
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
@@ -230,35 +245,31 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
return err
|
||||
}
|
||||
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return nil
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groupIDs) == 1 {
|
||||
if createdCount == 0 {
|
||||
return err
|
||||
}
|
||||
globalErr = errors.Join(globalErr, err)
|
||||
// continue updating other groups
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
createdCount++
|
||||
snaps = append(snaps, snap)
|
||||
changes = append(changes, change)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
go am.dispatchAffected(ctx, accountID, snaps, changes)
|
||||
|
||||
return globalErr
|
||||
}
|
||||
@@ -277,12 +288,13 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snaps []*affectedpeers.Snapshot
|
||||
var changes []affectedpeers.Change
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup)
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
events, snap, err := am.updateSingleGroup(ctx, accountID, userID, newGroup, change)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groups) == 1 {
|
||||
@@ -292,27 +304,22 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
continue
|
||||
}
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
snaps = append(snaps, snap)
|
||||
changes = append(changes, change)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
go am.dispatchAffected(ctx, accountID, snaps, changes)
|
||||
|
||||
return globalErr
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) {
|
||||
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, change affectedpeers.Change) ([]func(), *affectedpeers.Snapshot, error) {
|
||||
var events []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
@@ -333,9 +340,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
return nil
|
||||
|
||||
var err error
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
})
|
||||
return events, err
|
||||
return events, snap, err
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
@@ -438,6 +448,8 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -445,26 +457,23 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
deletedGroups, allErrors = collectDeletableGroups(ctx, transaction, accountID, userID, groupIDs, extraSettings.FlowGroups)
|
||||
for _, group := range deletedGroups {
|
||||
groupIDsToDelete = append(groupIDsToDelete, group.ID)
|
||||
}
|
||||
|
||||
if len(groupIDsToDelete) == 0 {
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// Delete: compute affected peers from the PRE-delete state. The groups,
|
||||
// their members and the entities referencing them still exist, so a plain
|
||||
// Load+Expand captures everyone — no removed-peer folding needed.
|
||||
change = affectedpeers.Change{ChangedGroupIDs: groupIDsToDelete}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -483,25 +492,47 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// collectDeletableGroups loads and validates each group for deletion, returning
|
||||
// the groups that may be deleted and the joined validation errors for the rest.
|
||||
func collectDeletableGroups(ctx context.Context, transaction store.Store, accountID, userID string, groupIDs, flowGroups []string) ([]*types.Group, error) {
|
||||
var deletable []*types.Group
|
||||
var allErrors error
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, flowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
deletable = append(deletable, group)
|
||||
}
|
||||
return deletable, allErrors
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -511,9 +542,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -521,8 +550,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
// GroupAddResource appends resource to the group
|
||||
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
var err error
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
@@ -534,12 +564,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -549,29 +578,32 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{
|
||||
ChangedGroupIDs: []string{groupID},
|
||||
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
// The removed peer is carried in change.RemovedPeersByGroup and folded in
|
||||
// only when the group is linked, so loading post-removal is correct.
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -581,9 +613,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -591,8 +621,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
// GroupDeleteResource removes resource from the group
|
||||
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
var err error
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
@@ -604,8 +635,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
// Load before persisting the removal, so the snapshot still maps the group
|
||||
// to the resource and the bridge can reach its routing peers.
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -619,9 +651,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -832,49 +862,103 @@ func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store,
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
// It fetches each collection once and checks all groupIDs against them in memory.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
groupSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() || group.HasResources() {
|
||||
for _, ns := range nameServerGroups {
|
||||
if anyInSet(ns.Groups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, r := range routes {
|
||||
if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, router := range routers {
|
||||
if anyInSet(router.PeerGroups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestGroupIPv6Assignment(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"},
|
||||
}, false)
|
||||
|
||||
@@ -479,7 +479,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
|
||||
peer, _, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -728,7 +728,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
}
|
||||
|
||||
login := func() error {
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return err
|
||||
@@ -746,7 +746,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
|
||||
go func(peerLogin types.PeerLogin, counterStart *int32) {
|
||||
defer wgPeer.Done()
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -38,13 +39,13 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
@@ -97,7 +98,7 @@ type MockAccountManager struct {
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
|
||||
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
@@ -132,6 +133,7 @@ type MockAccountManager struct {
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
ExpandAndUpdateAffectedFunc func(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
|
||||
@@ -209,6 +211,12 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
|
||||
if am.ExpandAndUpdateAffectedFunc != nil {
|
||||
am.ExpandAndUpdateAffectedFunc(ctx, accountID, snap, change)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
if am.BufferUpdateAccountPeersFunc != nil {
|
||||
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)
|
||||
@@ -337,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
@@ -416,11 +424,11 @@ func (am *MockAccountManager) AddPeer(
|
||||
userId string,
|
||||
peer *nbpeer.Peer,
|
||||
temporary bool,
|
||||
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if am.AddPeerFunc != nil {
|
||||
return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
|
||||
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
|
||||
}
|
||||
|
||||
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
||||
@@ -854,11 +862,11 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account
|
||||
}
|
||||
|
||||
// LoginPeer mocks LoginPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if am.LoginPeerFunc != nil {
|
||||
return am.LoginPeerFunc(ctx, login)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
}
|
||||
|
||||
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -57,19 +59,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
SearchDomainsEnabled: searchDomainEnabled,
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{DistributionGroupIDs: newNSGroup.Groups}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -81,9 +83,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
}
|
||||
@@ -102,7 +102,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
|
||||
@@ -115,12 +116,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
|
||||
if err != nil {
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
change = affectedpeers.Change{DistributionGroupIDs: slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -132,9 +133,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -150,7 +149,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
}
|
||||
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
|
||||
@@ -158,8 +158,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
|
||||
if err != nil {
|
||||
// Load before delete: the post-delete state no longer references the groups.
|
||||
change = affectedpeers.Change{DistributionGroupIDs: nsGroup.Groups}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -175,9 +176,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -224,24 +223,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
||||
return validateGroups(nameserverGroup.Groups, groups)
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||
if !primary && len(domains) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||
|
||||
@@ -896,11 +896,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -127,30 +127,39 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{Networks: []*types.Network{network}}
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get resources in network: %w", err)
|
||||
}
|
||||
|
||||
for _, resource := range resources {
|
||||
event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete resource: %w", err)
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event...)
|
||||
}
|
||||
|
||||
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
netRouters, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get routers in network: %w", err)
|
||||
}
|
||||
|
||||
for _, router := range routers {
|
||||
event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
|
||||
var lerr error
|
||||
if snap, lerr = affectedpeers.Load(ctx, transaction, accountID, change); lerr != nil {
|
||||
return lerr
|
||||
}
|
||||
|
||||
for _, resource := range resources {
|
||||
deleted, event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete resource: %w", err)
|
||||
}
|
||||
change.Resources = append(change.Resources, deleted)
|
||||
eventsToStore = append(eventsToStore, event...)
|
||||
}
|
||||
|
||||
for _, router := range netRouters {
|
||||
deleted, event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete router: %w", err)
|
||||
}
|
||||
change.Routers = append(change.Routers, deleted)
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
@@ -178,7 +187,7 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
|
||||
m.accountManager.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -29,7 +30,7 @@ type Manager interface {
|
||||
GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error)
|
||||
UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error)
|
||||
DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error
|
||||
DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error)
|
||||
DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) (*types.NetworkResource, []func(), error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -114,45 +115,12 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{Resources: []*types.NetworkResource{resource}}
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil {
|
||||
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
||||
}
|
||||
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
|
||||
event := func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
|
||||
res := nbtypes.Resource{
|
||||
ID: resource.ID,
|
||||
Type: nbtypes.ResourceType(resource.Type.String()),
|
||||
}
|
||||
for _, groupID := range resource.GroupIDs {
|
||||
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add resource to group: %w", err)
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
var txErr error
|
||||
eventsToStore, snap, txErr = m.createResourceInTransaction(ctx, transaction, userID, resource, change)
|
||||
return txErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create network resource: %w", err)
|
||||
@@ -162,11 +130,55 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
|
||||
m.accountManager.ExpandAndUpdateAffected(ctx, resource.AccountID, snap, change)
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) createResourceInTransaction(ctx context.Context, transaction store.Store, userID string, resource *types.NetworkResource, change affectedpeers.Change) ([]func(), *affectedpeers.Snapshot, error) {
|
||||
_, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil {
|
||||
return nil, nil, status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
||||
}
|
||||
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
|
||||
})
|
||||
|
||||
res := nbtypes.Resource{
|
||||
ID: resource.ID,
|
||||
Type: nbtypes.ResourceType(resource.Type.String()),
|
||||
}
|
||||
for _, groupID := range resource.GroupIDs {
|
||||
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to add resource to group: %w", err)
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, resource.AccountID); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
snap, err := affectedpeers.Load(ctx, transaction, resource.AccountID, change)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return eventsToStore, snap, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
@@ -207,6 +219,8 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
resource.Prefix = prefix
|
||||
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
if err != nil {
|
||||
@@ -232,6 +246,14 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get old resource groups: %w", err)
|
||||
}
|
||||
for _, g := range oldGroups {
|
||||
oldResource.GroupIDs = append(oldResource.GroupIDs, g.ID)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
@@ -247,6 +269,11 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
|
||||
})
|
||||
|
||||
change = affectedpeers.Change{Resources: []*types.NetworkResource{oldResource, resource}}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, resource.AccountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
@@ -270,7 +297,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
}
|
||||
}()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
|
||||
m.accountManager.ExpandAndUpdateAffected(ctx, resource.AccountID, snap, change)
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
@@ -331,8 +358,26 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
}
|
||||
|
||||
var events []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
||||
existing, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get resource groups: %w", err)
|
||||
}
|
||||
for _, g := range oldGroups {
|
||||
existing.GroupIDs = append(existing.GroupIDs, g.ID)
|
||||
}
|
||||
change = affectedpeers.Change{Resources: []*types.NetworkResource{existing}}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete resource: %w", err)
|
||||
}
|
||||
@@ -352,51 +397,53 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
|
||||
m.accountManager.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) {
|
||||
func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) (*types.NetworkResource, []func(), error) {
|
||||
resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network resource: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
if resource.NetworkID != networkID {
|
||||
return nil, errors.New("resource not part of network")
|
||||
return nil, nil, errors.New("resource not part of network")
|
||||
}
|
||||
|
||||
groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get resource groups: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to get resource groups: %w", err)
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
|
||||
for _, group := range groups {
|
||||
resource.GroupIDs = append(resource.GroupIDs, group.ID)
|
||||
|
||||
event, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, accountID, userID, group.ID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to remove resource from group: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to remove resource from group: %w", err)
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
err = transaction.DeleteNetworkResource(ctx, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete network resource: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to delete network resource: %w", err)
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, resourceID, accountID, activity.NetworkResourceDeleted, resource.EventMeta(network))
|
||||
})
|
||||
|
||||
return eventsToStore, nil
|
||||
return resource, eventsToStore, nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
@@ -431,6 +478,6 @@ func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) {
|
||||
return []func(){}, nil
|
||||
func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) (*types.NetworkResource, []func(), error) {
|
||||
return nil, []func(){}, nil
|
||||
}
|
||||
|
||||
@@ -9,13 +9,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ type Manager interface {
|
||||
GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error)
|
||||
UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error)
|
||||
DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error
|
||||
DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error)
|
||||
DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (*types.NetworkRouter, func(), error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -90,6 +90,8 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
}
|
||||
|
||||
var network *networkTypes.Network
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{Routers: []*types.NetworkRouter{router}}
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
@@ -112,6 +114,10 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, router.AccountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -120,7 +126,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
|
||||
m.accountManager.ExpandAndUpdateAffected(ctx, router.AccountID, snap, change)
|
||||
|
||||
return router, nil
|
||||
}
|
||||
@@ -156,36 +162,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
}
|
||||
|
||||
var network *networkTypes.Network
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if existing.AccountID != router.AccountID {
|
||||
return status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
err = transaction.UpdateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
var txErr error
|
||||
network, snap, change, txErr = m.updateRouterInTransaction(ctx, transaction, router)
|
||||
return txErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -193,11 +175,47 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
|
||||
m.accountManager.ExpandAndUpdateAffected(ctx, router.AccountID, snap, change)
|
||||
|
||||
return router, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction store.Store, router *types.NetworkRouter) (*networkTypes.Network, *affectedpeers.Snapshot, affectedpeers.Change, error) {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if existing.AccountID != router.AccountID {
|
||||
return nil, nil, affectedpeers.Change{}, status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
return nil, nil, affectedpeers.Change{}, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
if err = transaction.UpdateNetworkRouter(ctx, router); err != nil {
|
||||
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, router.AccountID); err != nil {
|
||||
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
change := affectedpeers.Change{Routers: []*types.NetworkRouter{existing, router}}
|
||||
snap, err := affectedpeers.Load(ctx, transaction, router.AccountID, change)
|
||||
if err != nil {
|
||||
return nil, nil, affectedpeers.Change{}, err
|
||||
}
|
||||
|
||||
return network, snap, change, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
if err != nil {
|
||||
@@ -208,8 +226,19 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
}
|
||||
|
||||
var event func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
change = affectedpeers.Change{Routers: []*types.NetworkRouter{existing}}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete network router: %w", err)
|
||||
}
|
||||
@@ -227,36 +256,36 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
|
||||
event()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
|
||||
m.accountManager.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
|
||||
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (*types.NetworkRouter, func(), error) {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network router: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if router.NetworkID != networkID {
|
||||
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
|
||||
return nil, nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
|
||||
}
|
||||
|
||||
err = transaction.DeleteNetworkRouter(ctx, accountID, routerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete network router: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to delete network router: %w", err)
|
||||
}
|
||||
|
||||
event := func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, routerID, accountID, activity.NetworkRouterDeleted, router.EventMeta(network))
|
||||
}
|
||||
|
||||
return event, nil
|
||||
return router, event, nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
@@ -287,6 +316,9 @@ func (m *mockManager) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
|
||||
return func() {}, nil
|
||||
func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (*types.NetworkRouter, func(), error) {
|
||||
return nil, func() {
|
||||
// no-op mock: returns zero values so tests that don't exercise router deletion
|
||||
// can satisfy the Manager interface without a real store.
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -73,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
@@ -105,35 +106,22 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
expired := peer.Status != nil && peer.Status.LoginExpired
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if expired {
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
// A login-expired peer reconnecting, or an embedded proxy peer flipping to
|
||||
// connected (which triggers SynthesizePrivateServiceZones), must refresh the
|
||||
// peers reachable from it. The embedded-proxy fan-out tolerates a dispatch error.
|
||||
if peer.Status != nil && peer.Status.LoginExpired {
|
||||
affectedPeerIDs := am.markConnectedAffectedPeers(ctx, accountID, peer.ID, nmap)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}, affectedPeerIDs); err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// An embedded proxy peer flipping to connected is the trigger for
|
||||
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
|
||||
// tunnel IP. Without an account-wide netmap recompute, user peers keep
|
||||
// the stale synth (or no synth at all on first connect) until some
|
||||
// other change pokes the controller. Fire OnPeersUpdated so the
|
||||
// buffered recompute fans the new state out to every peer.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
affectedPeerIDs := am.markConnectedAffectedPeers(ctx, accountID, peer.ID, nmap)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -141,6 +129,25 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return nil
|
||||
}
|
||||
|
||||
// schedulePeerExpirations reschedules the account's login/inactivity expiration
|
||||
// timers for an SSO peer that just connected.
|
||||
func (am *DefaultAccountManager) schedulePeerExpirations(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||
if !peer.AddedWithSSOLogin() {
|
||||
return nil
|
||||
}
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected marks a peer as disconnected, but only when the
|
||||
// stored session token matches the one passed in. A mismatch means a
|
||||
// newer stream has already taken ownership of the peer — disconnects from
|
||||
@@ -175,11 +182,12 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
|
||||
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
|
||||
// offline, drive an account-wide netmap recompute so the synthesized
|
||||
// DNS records that pointed at it are pulled. Without this the records
|
||||
// linger client-side at TTL until something else triggers a refresh.
|
||||
// offline, refresh the peers that had synthesized records pointing at
|
||||
// it so they pull the stale entries instead of waiting out TTL.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -346,7 +354,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
affectedPeerIDs = append(affectedPeerIDs, peer.ID)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -501,10 +512,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return status.NewPeerNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var eventsToStore []func()
|
||||
|
||||
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||
@@ -513,8 +520,38 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return status.NewPeerInUseError(peerID, serviceID)
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
change := affectedpeers.Change{ChangedPeerIDs: []string{peerID}}
|
||||
settings, eventsToStore, snap, err := am.deletePeerInTransaction(ctx, accountID, userID, peerID, change)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if err = am.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
|
||||
}
|
||||
|
||||
affectedPeerIDs := snap.Expand(ctx, accountID, change)
|
||||
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deletePeerInTransaction loads the peer + settings, captures the affected-peers
|
||||
// snapshot (before the delete, while the peer's group memberships still exist),
|
||||
// then deletes the peer and bumps the network serial — all in one transaction.
|
||||
func (am *DefaultAccountManager) deletePeerInTransaction(ctx context.Context, accountID, userID, peerID string, change affectedpeers.Change) (*types.Settings, []func(), *affectedpeers.Snapshot, error) {
|
||||
var settings *types.Settings
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -528,8 +565,11 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings)
|
||||
if err != nil {
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings); err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
}
|
||||
|
||||
@@ -539,23 +579,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if err = am.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
|
||||
}
|
||||
|
||||
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return settings, eventsToStore, snap, err
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
@@ -694,10 +718,10 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
|
||||
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
|
||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||
// The peer property is just a placeholder for the Peer properties to pass further
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
}
|
||||
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
@@ -713,7 +737,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
|
||||
if err == nil {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
return nil, nil, nil, false, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
}
|
||||
|
||||
opEvent := &activity.Event{
|
||||
@@ -724,7 +748,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
peerAddConfig, err := am.processPeerAddAuth(ctx, accountID, userID, encodedHashedKey, peer, temporary, addedByUser, addedBySetupKey, opEvent)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
accountID = peerAddConfig.AccountID
|
||||
ephemeral := peerAddConfig.Ephemeral
|
||||
@@ -739,7 +763,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
|
||||
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||
return nil, nil, nil, false, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
@@ -765,7 +789,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
@@ -783,30 +807,30 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed getting network: %w", err)
|
||||
}
|
||||
|
||||
maxAttempts := 10
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
netPrefix, err := netip.ParsePrefix(network.Net.String())
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("parse network prefix: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("parse network prefix: %w", err)
|
||||
}
|
||||
freeIP, err := types.AllocateRandomPeerIP(netPrefix)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get free IP: %w", err)
|
||||
}
|
||||
|
||||
var freeLabel string
|
||||
if ephemeral || attempt > 1 {
|
||||
freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
} else {
|
||||
freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
}
|
||||
newPeer.DNSLabel = freeLabel
|
||||
@@ -828,11 +852,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
if allocate {
|
||||
v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("parse IPv6 prefix: %w", err)
|
||||
}
|
||||
freeIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("allocate peer IPv6: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("allocate peer IPv6: %w", err)
|
||||
}
|
||||
newPeer.IPv6 = freeIPv6
|
||||
}
|
||||
@@ -905,10 +929,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to add peer to database: %w", err)
|
||||
}
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
return nil, nil, nil, false, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
@@ -916,7 +940,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
|
||||
}
|
||||
if newPeer.Status != nil && newPeer.Status.RequiresApproval {
|
||||
requiresApproval := newPeer.Status != nil && newPeer.Status.RequiresApproval
|
||||
if requiresApproval {
|
||||
opEvent.Meta["pending_approval"] = true
|
||||
}
|
||||
|
||||
@@ -924,12 +949,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
}
|
||||
|
||||
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
|
||||
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, newPeer, !requiresApproval)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
changedPeerIDs := []string{newPeer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||
}
|
||||
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
return p, nmap, pc, err
|
||||
return newPeer, network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
|
||||
@@ -1011,17 +1042,51 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
if err != nil {
|
||||
nmap, resPostureChecks, dnsFwdPort, err := am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || updated || (len(postureChecks) > 0 || versionChanged) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
|
||||
return peer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
|
||||
// peer's own validated network map is bidirectional for policy and routing
|
||||
// reachability, so when the peer stays valid and no source-posture gate is in
|
||||
// play it already lists every affected peer — reuse it and skip the full
|
||||
// dependency walk. Posture checks gate the source side of a policy only, so a
|
||||
// metadata change that flips a posture result removes this peer from others'
|
||||
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
|
||||
// back to the resolver.
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
|
||||
if peerNotValid || (metaUpdated && hasPostureChecks) {
|
||||
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
|
||||
}
|
||||
return affectedPeerIDsFromNetworkMap(nmap, peerID)
|
||||
}
|
||||
|
||||
// markConnectedAffectedPeers resolves the peers affected when a peer connects
|
||||
// (login-expiry reconnect or embedded-proxy connect). The connecting peer's
|
||||
// network map already lists them bidirectionally — the synthesized
|
||||
// private-service policy puts proxy access-group members in the proxy peer's own
|
||||
// map, and these edges carry no source-posture gate. An invalid peer has an
|
||||
// empty map, so fall back to the resolver in that case.
|
||||
func (am *DefaultAccountManager) markConnectedAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap) []string {
|
||||
if nmap == nil || len(nmap.Peers)+len(nmap.OfflinePeers) == 0 {
|
||||
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
|
||||
}
|
||||
return affectedPeerIDsFromNetworkMap(nmap, peerID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
|
||||
// Try registering it.
|
||||
@@ -1037,12 +1102,12 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
|
||||
return nil, nil, nil, false, status.Errorf(status.Internal, "failed while logging in peer")
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
|
||||
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return am.handlePeerLoginNotFound(ctx, login, err)
|
||||
@@ -1054,20 +1119,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
if login.UserID == "" {
|
||||
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var updateRemotePeers bool
|
||||
var isPeerUpdated bool
|
||||
var ipv6CapabilityChanged bool
|
||||
var postureChecks []*posture.Checks
|
||||
var shouldStorePeer bool
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -1076,9 +1138,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return err
|
||||
}
|
||||
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStorePeer := false
|
||||
|
||||
if login.UserID != "" {
|
||||
if peer.UserID != login.UserID {
|
||||
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
||||
@@ -1092,7 +1151,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1101,23 +1159,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return err
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if isPeerUpdated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
shouldStorePeer = true
|
||||
|
||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
}
|
||||
|
||||
if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 {
|
||||
@@ -1133,23 +1177,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, peer, !isRequiresApproval)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if isStatusChanged || shouldStorePeer {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, false, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
|
||||
return p, nmap, pc, err
|
||||
return peer, network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
// ExtendPeerSession refreshes the peer's SSO session deadline by updating
|
||||
@@ -1225,6 +1274,50 @@ func (am *DefaultAccountManager) ExtendPeerSession(ctx context.Context, peerPubK
|
||||
return refreshed.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration), nil
|
||||
}
|
||||
|
||||
// getPeerLoginInfo computes the login/register response data (network, posture
|
||||
// checks, SSH) from the store without building the peer's full network map.
|
||||
func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer, isValid bool) (*types.Network, []*posture.Checks, bool, error) {
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("get account network: %w", err)
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return network, nil, false, nil
|
||||
}
|
||||
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
return network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
|
||||
for _, g := range peerGroups {
|
||||
peerGroupIDs[g.ID] = struct{}{}
|
||||
}
|
||||
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1407,6 +1500,100 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// ExpandAndUpdateAffected expands a Snapshot (loaded INSIDE the now-committed
|
||||
// transaction) into the affected peers and dispatches the network-map refresh.
|
||||
// Pure in-memory work plus dispatch, so it runs AFTER commit — the fan-out walk
|
||||
// never holds the write lock, over the consistent in-tx snapshot. Exported so the
|
||||
// networks sub-package managers (which hold only account.Manager) share it.
|
||||
func (am *DefaultAccountManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
|
||||
go am.dispatchAffected(ctx, accountID, []*affectedpeers.Snapshot{snap}, []affectedpeers.Change{change})
|
||||
}
|
||||
|
||||
// dispatchAffected expands one or more (snapshot, change) pairs — collected across
|
||||
// one or several transactions — unions their affected peers, and dispatches a
|
||||
// single network-map refresh. Each snapshot must already be loaded inside its
|
||||
// transaction; this runs AFTER commit (pure in-memory + dispatch). It is spawned
|
||||
// in a goroutine that outlives the request, so it detaches from the request
|
||||
// context's cancellation up front.
|
||||
func (am *DefaultAccountManager) dispatchAffected(ctx context.Context, accountID string, snaps []*affectedpeers.Snapshot, changes []affectedpeers.Change) {
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
|
||||
var lists [][]string
|
||||
for i, snap := range snaps {
|
||||
if snap == nil {
|
||||
continue
|
||||
}
|
||||
lists = append(lists, snap.Expand(ctx, accountID, changes[i]))
|
||||
}
|
||||
|
||||
affectedPeerIDs := unionStrings(lists...)
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for account %s", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("updating %d affected peers for account %s: %v", len(affectedPeerIDs), accountID, affectedPeerIDs)
|
||||
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// unionStrings concatenates the given string lists into one deduplicated slice,
|
||||
// preserving first-occurrence order.
|
||||
func unionStrings(lists ...[]string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var out []string
|
||||
for _, list := range lists {
|
||||
for _, id := range list {
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// affectedPeerIDsFromNetworkMap returns the peer IDs referenced by a peer's
|
||||
// network map (its connected and offline peers, which include routing and proxy
|
||||
// peers), excluding the peer itself. For a freshly added peer these are, by ACL
|
||||
// symmetry, exactly the peers its addition affects.
|
||||
func affectedPeerIDsFromNetworkMap(nmap *types.NetworkMap, selfPeerID string) []string {
|
||||
if nmap == nil {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(nmap.Peers)+len(nmap.OfflinePeers))
|
||||
ids := make([]string, 0, len(nmap.Peers)+len(nmap.OfflinePeers))
|
||||
add := func(peers []*nbpeer.Peer) {
|
||||
for _, p := range peers {
|
||||
if p == nil || p.ID == "" || p.ID == selfPeerID {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p.ID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p.ID] = struct{}{}
|
||||
ids = append(ids, p.ID)
|
||||
}
|
||||
}
|
||||
add(nmap.Peers)
|
||||
add(nmap.OfflinePeers)
|
||||
return ids
|
||||
}
|
||||
|
||||
// resolveAffectedPeersForPeerChanges loads a snapshot and expands it for a peer
|
||||
// change. The graph is unchanged by these paths, so it runs out of the mutating
|
||||
// transaction (after commit); the resolver derives the peers' group memberships
|
||||
// during the walk, so the caller passes only the changed peer IDs.
|
||||
func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string {
|
||||
change := affectedpeers.Change{ChangedPeerIDs: changedPeerIDs}
|
||||
snap, err := affectedpeers.Load(ctx, s, accountID, change)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load snapshot for affected peers: %v", err)
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, accountID, change)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user