Compare commits

..

5 Commits

Author SHA1 Message Date
riccardom
b2c5732847 You now need to explicitly call these around 2026-06-16 10:42:14 +02:00
riccardom
0340893854 Now we need to apply MDM in the GetConfig 2026-06-15 18:36:38 +02:00
riccardom
874195440c Removes static vars 2026-06-15 17:32:46 +02:00
riccardom
bec26d5a14 Removes dead code 2026-06-15 13:01:04 +02:00
riccardom
db2c9b6f49 MDM Android mobile wiring 2026-06-15 12:04:26 +02:00
173 changed files with 2990 additions and 10456 deletions

View File

@@ -9,13 +9,10 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.6"
GORELEASER_VER: "v2.16.0"
SIGN_PIPE_VER: "v0.1.5"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
flags: ""
SKIP_PUBLISH: "true"
SKIP_DOCKER_PUSH: "false"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -133,6 +130,8 @@ jobs:
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env:
flags: ""
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -144,27 +143,8 @@ jobs:
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- name: Set snapshot flag
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
if [[ "x-${{ github.repository }}" != "x-netbirdio/netbird" ]]; then
echo "SKIP_DOCKER_PUSH=true" >> $GITHUB_ENV
fi
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
@@ -181,8 +161,6 @@ jobs:
${{ runner.os }}-go-releaser-
- name: Install modules
run: go mod tidy
- name: run openapi generator
run: bash shared/management/http/api/generate.sh
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
@@ -232,8 +210,6 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
SKIP_DOCKER_PUSH: ${{ env.SKIP_DOCKER_PUSH }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
@@ -356,22 +332,8 @@ jobs:
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- name: Set snapshot flag
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
@@ -431,7 +393,6 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '

View File

@@ -1,7 +1,5 @@
version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
- SKIP_DOCKER_PUSH={{ if index .Env "SKIP_DOCKER_PUSH" }}{{ .Env.SKIP_DOCKER_PUSH }}{{ else }}false{{ end }}
project_name: netbird
builds:
- id: netbird-wasm
@@ -76,8 +74,6 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -92,8 +88,6 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -108,8 +102,6 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -130,8 +122,6 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -146,8 +136,6 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -162,8 +150,6 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -184,8 +170,6 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -238,192 +222,670 @@ nfpms:
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers_v2:
- id: netbird
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird
images:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
platforms:
- linux/amd64
- linux/arm64
- linux/arm/6
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-rootless
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird
images:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}-rootless"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
platforms:
- linux/amd64
- linux/arm64
- linux/arm/6
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: relay
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-relay
images:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: signal
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-signal
images:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: management
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-mgmt
images:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: upload
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-upload
images:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-server
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-server
images:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-proxy
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-proxy
images:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
ids:
- netbird-signal
goarch: amd64
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
ids:
- netbird-signal
goarch: arm64
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
ids:
- netbird-signal
goarch: arm
goarm: 6
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:latest
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/signal:latest
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/management:{{ .Version }}
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:latest
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:debug-latest
image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/relay:latest
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:latest
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:debug-latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/upload:latest
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
brews:
- ids:
- default
skip_upload: "{{ .Env.SKIP_PUBLISH }}"
repository:
owner: netbirdio
name: homebrew-tap
@@ -440,7 +902,6 @@ brews:
uploads:
- name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_deb
mode: archive
@@ -449,7 +910,6 @@ uploads:
method: PUT
- name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_rpm
mode: archive

View File

@@ -1,6 +1,5 @@
version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
project_name: netbird-ui
builds:
- id: netbird-ui
@@ -102,7 +101,6 @@ nfpms:
uploads:
- name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_ui_deb
mode: archive
@@ -111,7 +109,6 @@ uploads:
method: PUT
- name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_ui_rpm
mode: archive

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.24
FROM alpine:3.23.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
@@ -21,7 +21,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
ARG NETBIRD_BINARY=netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -4,7 +4,7 @@
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.24
FROM alpine:3.22.0
RUN apk add --no-cache \
bash \
@@ -27,7 +27,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
ARG NETBIRD_BINARY=netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
@@ -75,6 +76,13 @@ type Client struct {
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
// mdmLoader holds the per-Client MDM policy source. Set by
// SetMDMPolicyFetcher (called from the Kotlin side). Each Run
// passes this loader to the resolved Config so applyMDMPolicy
// picks up the active overlay. Nil means "MDM enforcement off
// for this Client".
mdmLoader *mdm.Loader
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
@@ -129,6 +137,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
if err != nil {
return err
}
c.applyMDMOverlay(cfg)
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
@@ -173,6 +182,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
if err != nil {
return err
}
c.applyMDMOverlay(cfg)
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
@@ -230,6 +240,7 @@ func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (strin
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
c.applyMDMOverlay(cfg)
cacheDir = platformFiles.CacheDir()
}

80
client/android/mdm.go Normal file
View File

@@ -0,0 +1,80 @@
//go:build android
package android
import (
"encoding/json"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
)
// PolicyFetcher is the mobile-side bridge for the MDM managed-config
// snapshot. The native layer (Kotlin) implements this and registers
// the instance per Client via Client.SetMDMPolicyFetcher. Every
// invocation of fetchJSON must read the current RestrictionsManager
// state and return the result as a JSON-encoded map[string]any string.
//
// JSON is used because gomobile does not support map[string]any
// crossing the JNI boundary — the adapter on the Go side parses the
// string back into the map[string]any expected by mdm.Loader.
//
// Return value contract:
// - "" (empty) : interpreted as "no MDM source / no managed keys"
// - "{}" : managed config explicitly empty
// - "{...}" : JSON object with key/value pairs
// - malformed JSON : logged and treated as empty
type PolicyFetcher interface {
FetchJSON() string
}
// jsonFetcherAdapter wraps a gomobile-exposed PolicyFetcher into the
// internal mdm.PolicyFetcher interface, taking care of JSON decoding
// on every Fetch.
type jsonFetcherAdapter struct {
inner PolicyFetcher
}
func (a *jsonFetcherAdapter) Fetch() map[string]any {
raw := a.inner.FetchJSON()
if raw == "" {
return nil
}
var out map[string]any
if err := json.Unmarshal([]byte(raw), &out); err != nil {
log.Warnf("MDM mobile fetcher: invalid JSON payload from native: %v", err)
return nil
}
return out
}
// SetMDMPolicyFetcher registers the native-provided MDM policy fetcher
// on this Client. Call once from the gomobile-init code (Kotlin
// Application.onCreate or Service onCreate) before invoking Run /
// RunWithoutLogin. Passing nil disables MDM enforcement on this
// Client.
//
// The fetcher is held as a *mdm.Loader instance on the Client (no
// package-level state) — multiple Clients in the same process get
// independent Loaders, and tests can inject fakes per Client.
func (c *Client) SetMDMPolicyFetcher(p PolicyFetcher) {
if p == nil {
c.mdmLoader = mdm.NewLoader(nil)
return
}
c.mdmLoader = mdm.NewLoader(&jsonFetcherAdapter{inner: p})
}
// applyMDMOverlay applies the Client-held MDM Loader's current policy
// on top of the just-read Config. Called immediately after every
// UpdateOrCreateConfig — profilemanager's apply() initialises the
// policy to empty and leaves overlay responsibility to the lifecycle
// owner. No-op when no fetcher was registered.
func (c *Client) applyMDMOverlay(cfg *profilemanager.Config) {
if cfg == nil || c.mdmLoader == nil {
return
}
cfg.ApplyMDMPolicy(c.mdmLoader.Load())
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
@@ -23,7 +24,6 @@ const (
// Profile represents a profile for gomobile
type Profile struct {
ID string
Name string
IsActive bool
}
@@ -53,10 +53,10 @@ func (p *ProfileArray) Get(i int) *Profile {
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Legacy work profile config
├── work.state.json ← Legacy work profile state
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.json ← ID profile config
── 4c5f5c8198c3989cffb5b5394f5a7ae0.state.json ← ID profile state
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
@@ -99,7 +99,6 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
ID: p.ID.String(),
Name: p.Name,
IsActive: p.IsActive,
})
@@ -109,65 +108,55 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return nil, fmt.Errorf("failed to get active profile: %w", err)
return "", fmt.Errorf("failed to get active profile: %w", err)
}
// ActiveProfileState only stores the ID (and username), not the display
// name. Resolve the ID to the full profile so callers get the real Name.
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID.String(), androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
}
return &Profile{ID: prof.ID.String(), Name: prof.Name, IsActive: true}, nil
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(id string) error {
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(id),
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", id)
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
if err != nil {
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profile.ID)
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(id string) error {
configPath, err := pm.getProfileConfigPath(id)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return fmt.Errorf("id '%s' is not valid", id)
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", id)
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
@@ -185,57 +174,53 @@ func (pm *ProfileManager) LogoutProfile(id string) error {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", id)
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(id string) error {
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profilemanager.ID(id), androidUsername); err != nil {
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", id)
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
if id == profilemanager.DefaultProfileName {
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, id+".json"), nil
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile id
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
return pm.getProfileConfigPath(id)
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName {
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, id+".state.json"), nil
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
@@ -245,7 +230,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile.ID)
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
@@ -255,5 +240,18 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile.ID)
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
@@ -96,19 +97,17 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
handle := activeProf.ID.String()
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
ProfileName: &handle,
ProfileName: &activeProf.Name,
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.ID)
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -172,13 +171,14 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
return activeProf, nil
}
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
err := switchProfile(context.Background(), profileName, username)
if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -206,15 +206,11 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
return nil
}
// switchProfile asks the daemon to switch to the profile identified by
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
// ID so the caller can update the local active-profile state without
// re-resolving the handle.
func switchProfile(ctx context.Context, handle string, username string) (profilemanager.ID, error) {
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
@@ -222,15 +218,15 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
client := proto.NewDaemonServiceClient(conn)
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
Username: &username,
})
if err != nil {
return "", fmt.Errorf("switch profile failed: %v", err)
return fmt.Errorf("switch profile failed: %v", err)
}
return profilemanager.ID(resp.Id), nil
return nil
}
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
@@ -253,8 +249,13 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
if err != nil {
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
// CLI standalone login: profilemanager no longer auto-applies MDM,
// so layer in the OS-native policy here. Desktop builds construct
// a Loader with no fetcher — the build-tagged loadPlatform reads
// the registry/plist directly.
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -282,7 +283,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string, profileID profilemanager.ID) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
@@ -296,7 +297,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileID)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -311,10 +312,10 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileID profilemanager.ID) (*auth.TokenInfo, error) {
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileID)
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {

View File

@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "default",
Name: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -2,16 +2,11 @@ package cmd
import (
"context"
"errors"
"fmt"
"os/user"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -19,8 +14,6 @@ import (
"github.com/netbirdio/netbird/util"
)
var profileListShowID bool
var profileCmd = &cobra.Command{
Use: "profile",
Short: "Manage NetBird client profiles",
@@ -38,40 +31,27 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{
Use: "add <profile_name>",
Short: "Add a new profile",
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1),
RunE: addProfileFunc,
}
var profileRenameCmd = &cobra.Command{
Use: "rename <profile> <new_profile_name>",
Short: "Renames an existing profile",
Long: `Renames an existing profile (by a name, ID, or unique ID prefix). Profile name is free-form.`,
Args: cobra.ExactArgs(2),
RunE: renameProfileFunc,
}
var profileRemoveCmd = &cobra.Command{
Use: "remove <profile>",
Short: "Remove a profile",
Long: `Remove a profile by name, ID, or unique ID prefix.`,
Aliases: []string{"rm"},
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
Use: "remove <profile_name>",
Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
}
var profileSelectCmd = &cobra.Command{
Use: "select <profile>",
Use: "select <profile_name>",
Short: "Select a profile",
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Args: cobra.ExactArgs(1),
RunE: selectProfileFunc,
}
func init() {
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
}
func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
@@ -85,7 +65,6 @@ func setupCmd(cmd *cobra.Command) error {
return nil
}
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -104,33 +83,25 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return err
}
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
if profileListShowID {
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
} else {
fmt.Fprintln(tw, "NAME\tACTIVE")
}
for _, profile := range resp.Profiles {
marker := ""
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
// use a cross to indicate the passive profiles
activeMarker := "✗"
if profile.IsActive {
marker = "✓"
}
name := profilemanager.StripCtrlChars(profile.Name)
id := profilemanager.ID(profile.Id)
if profileListShowID {
fmt.Fprintf(tw, "%s\t%s\t%s\n", id.ShortID(), name, marker)
} else {
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
activeMarker = "✓"
}
cmd.Println(activeMarker, profile.Name)
}
return tw.Flush()
return nil
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
@@ -150,82 +121,21 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return fmt.Errorf("add profile request: %w", err)
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, profileName)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
}
func renameProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
newProfilename := args[1]
resp, err := daemonClient.RenameProfile(cmd.Context(), &proto.RenameProfileRequest{
Handle: handle,
Username: currUser.Username,
NewProfileName: newProfilename,
})
if err != nil {
return wrapAmbiguityError(err, handle)
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, newProfilename)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, newProfilename)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
cmd.Printf("Profile renamed from %s to %s\n", profilemanager.StripCtrlChars(resp.OldProfileName), profilemanager.StripCtrlChars(newProfilename))
cmd.Println("Profile added successfully:", profileName)
return nil
}
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
if err != nil {
return 0, err
}
n := 0
for _, p := range resp.Profiles {
if p.Name == name {
n++
}
}
return n, nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -243,17 +153,18 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: handle,
profileName := args[0]
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return wrapAmbiguityError(err, handle)
return err
}
cmd.Printf("Profile removed: %s\n", resp.Id)
cmd.Println("Profile removed successfully:", profileName)
return nil
}
@@ -263,7 +174,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
profileManager := profilemanager.NewProfileManager()
handle := args[0]
profileName := args[0]
currUser, err := user.Current()
if err != nil {
@@ -280,15 +191,32 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return wrapAmbiguityError(err, handle)
return fmt.Errorf("list profiles: %w", err)
}
if err := profileManager.SwitchProfile(profilemanager.ID(switchResp.Id)); err != nil {
var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
return err
}
@@ -303,30 +231,6 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
}
id := profilemanager.ID(switchResp.Id)
cmd.Printf("Profile switched to: %s\n", id.ShortID())
cmd.Println("Profile switched successfully to:", profileName)
return nil
}
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
// (which carry the resolver's message verbatim) into CLI-friendly text
// that points the user at --show-id.
func wrapAmbiguityError(err error, handle string) error {
if err == nil {
return nil
}
st, ok := gstatus.FromError(err)
if !ok {
return err
}
switch st.Code() {
case codes.InvalidArgument:
msg := st.Message()
if strings.Contains(msg, "ambiguous") {
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
}
case codes.NotFound:
return fmt.Errorf("profile %q not found", handle)
}
return err
}

View File

@@ -190,7 +190,6 @@ func init() {
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)
profileCmd.AddCommand(profileRenameCmd)
profileCmd.AddCommand(profileRemoveCmd)
profileCmd.AddCommand(profileSelectCmd)

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -128,12 +129,13 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
err = switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -186,10 +188,14 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
// CLI foreground path runs without the daemon Server: layer in the
// active MDM policy explicitly so a forced ManagementURL / PSK /
// other managed key actually takes effect on this run.
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.ID)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -260,10 +266,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
}
// set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID.String(), username.Username)
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
log.Warnf("setConfig method is not available in the daemon")
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
@@ -288,11 +294,10 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
return fmt.Errorf("setup login request: %v", err)
}
profileID := activeProf.ID.String()
loginRequest.ProfileName = &profileID
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.ID)
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -329,7 +334,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
}
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &profileID,
ProfileName: &activeProf.Name,
Username: &username,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)

View File

@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
}
sm := profilemanager.ServiceManager{}
created, err := sm.AddProfile("test1", currUser.Username)
err = sm.AddProfile("test1", currUser.Username)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return
}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: created.ID,
Name: "test1",
Username: currUser.Username,
})
if err != nil {

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -215,6 +216,10 @@ func New(opts Options) (*Client, error) {
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
// Embedded path runs without the daemon Server: apply the active
// MDM policy explicitly so a forced ManagementURL / PSK / other
// managed key takes effect on this embedded engine instance.
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey

View File

@@ -41,6 +41,7 @@ type ICEBind struct {
*wgConn.StdNetBind
transportNet transport.Net
filterFn udpmux.FilterFn
address wgaddr.Address
mtu uint16
@@ -60,11 +61,12 @@ type ICEBind struct {
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, address wgaddr.Address, mtu uint16) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
transportNet: transportNet,
filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn),
@@ -263,6 +265,7 @@ func (s *ICEBind) createOrUpdateMux() {
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},

View File

@@ -289,7 +289,7 @@ func setupICEBind(t *testing.T) *ICEBind {
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, address, 1280)
return NewICEBind(transportNet, nil, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {

View File

@@ -1,13 +1,10 @@
package device
import (
"fmt"
"net/netip"
"runtime/debug"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
@@ -44,13 +41,10 @@ type PacketCapture interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
capture atomic.Pointer[PacketCapture]
// panicHandler is invoked after a panic in the underlying device is
// recovered in Read or Write.
panicHandler atomic.Pointer[func()]
mutex sync.RWMutex
closeOnce sync.Once
filter PacketFilter
capture atomic.Pointer[PacketCapture]
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -76,7 +70,7 @@ func (d *FilteredDevice) Close() error {
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.deviceRead(bufs, sizes, offset); err != nil {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
@@ -118,7 +112,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RUnlock()
if filter == nil {
return d.deviceWrite(bufs, offset)
return d.Device.Write(bufs, offset)
}
filteredBufs := make([][]byte, 0, len(bufs))
@@ -131,44 +125,9 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
}
}
n, err := d.deviceWrite(filteredBufs, offset)
if err != nil {
return n, err
}
return n + dropped, nil
}
// deviceRead calls the underlying device Read, recovering from panics in the
// wintun read path and converting them into errors.
func (d *FilteredDevice) deviceRead(bufs [][]byte, sizes []int, offset int) (n int, err error) {
defer d.recoverFromPanic("read", &n, &err)
return d.Device.Read(bufs, sizes, offset)
}
// deviceWrite calls the underlying device Write, recovering from panics in the
// wintun write path and converting them into errors.
func (d *FilteredDevice) deviceWrite(bufs [][]byte, offset int) (n int, err error) {
defer d.recoverFromPanic("write", &n, &err)
return d.Device.Write(bufs, offset)
}
// recoverFromPanic converts a panic in the underlying device into a regular
// error and invokes the registered panic handler. The wintun read path is
// known to panic on zero-length packets that third-party filter drivers can
// place in the ring.
func (d *FilteredDevice) recoverFromPanic(op string, n *int, err *error) {
r := recover()
if r == nil {
return
}
log.Errorf("recovered panic in tun device %s: %v\n%s", op, r, debug.Stack())
*n = 0
*err = fmt.Errorf("tun device %s panic: %v", op, r)
if handler := d.panicHandler.Load(); handler != nil {
(*handler)()
}
n, err := d.Device.Write(filteredBufs, offset)
n += dropped
return n, err
}
// SetFilter sets packet filter to device
@@ -178,17 +137,6 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Unlock()
}
// SetPanicHandler registers a handler invoked after a recovered panic in Read
// or Write. The device is unusable after such a panic; the handler should
// trigger recreation of the interface. Pass nil to remove.
func (d *FilteredDevice) SetPanicHandler(handler func()) {
if handler == nil {
d.panicHandler.Store(nil)
return
}
d.panicHandler.Store(&handler)
}
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
// Uses atomic store so the hot path (Read/Write) is a single pointer load
// with no locking overhead when capture is off.

View File

@@ -221,60 +221,3 @@ func TestDeviceWrapperRead(t *testing.T) {
}
})
}
func TestDeviceWrapperReadPanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
// Reproduce the wintun zero-length packet panic (index out of range).
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Read([][]byte{{}}, []int{0}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}
func TestDeviceWrapperWritePanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, offset int) (int, error) {
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Write([][]byte{{0x45, 0x00}}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}

View File

@@ -32,6 +32,8 @@ type TunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
udpMux *udpmux.UniversalUDPMuxDefault
filterFn udpmux.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -102,6 +104,7 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}

View File

@@ -63,6 +63,7 @@ type WGIFaceOpts struct {
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn udpmux.FilterFn
DisableDNS bool
}

View File

@@ -11,7 +11,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -9,7 +9,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
if netstack.IsEnabled() {
wgIFace := &WGIface{

View File

@@ -10,7 +10,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),

View File

@@ -14,7 +14,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
userspaceBind: true,
@@ -30,7 +30,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
userspaceBind: true,

View File

@@ -8,6 +8,8 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -20,6 +22,10 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
@@ -37,6 +43,7 @@ type UniversalUDPMuxParams struct {
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
MTU uint16
}
@@ -61,6 +68,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
@@ -107,12 +115,15 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
}
}
// UDPConn is a wrapper around UDPMux conn that overrides WriteTo to drop packets destined for the overlay subnet.
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type UDPConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
address wgaddr.Address
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
}
// GetPacketConn returns the underlying PacketConn
@@ -121,18 +132,67 @@ func (u *UDPConn) GetPacketConn() net.PacketConn {
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
dst := udpAddr.AddrPort().Addr().Unmap()
if (u.address.Network.IsValid() && u.address.Network.Contains(dst)) || (u.address.IPv6Net.IsValid() && u.address.IPv6Net.Contains(dst)) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return 0, fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}
return u.handleUncachedAddress(b, addr)
}
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}
a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}
if u.address.Network.Contains(a) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}
func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
}
// GetSharedConn returns the shared udp conn
func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn {
return m.params.UDPConn
@@ -165,13 +225,6 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
return nil
}
src := udpAddr.AddrPort().Addr().Unmap()
wg := m.params.WGAddress
if (wg.Network.IsValid() && wg.Network.Contains(src)) || (wg.IPv6Net.IsValid() && wg.IPv6Net.Contains(src)) {
log.Debugf("dropping STUN message from overlay source %s", udpAddr)
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {

View File

@@ -66,7 +66,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,

View File

@@ -22,7 +22,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,

View File

@@ -118,8 +118,6 @@ func (c *ConnectClient) RunOniOS(
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
stateFilePath string,
cacheDir string,
logFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
@@ -129,9 +127,8 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, logFilePath)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {

View File

@@ -250,7 +250,6 @@ type BundleGenerator struct {
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
statePath string
cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
@@ -277,7 +276,6 @@ type GeneratorDependencies struct {
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
StatePath string // Path to the state file. If empty, the ServiceManager default path is used.
CPUProfile []byte
CapturePath string
RefreshStatus func()
@@ -301,7 +299,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
statePath: deps.StatePath,
cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
@@ -853,11 +850,8 @@ func (g *BundleGenerator) maskSecrets() {
}
func (g *BundleGenerator) addStateFile() error {
path := g.statePath
if path == "" {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if path == "" {
return nil
}

View File

@@ -1,36 +0,0 @@
//go:build ios
package debug
import (
"path/filepath"
log "github.com/sirupsen/logrus"
)
// swiftLogFile is the Swift app log written by the iOS app into the same log
// directory as the Go client log, so it can be collected into the bundle.
const swiftLogFile = "swift-log.log"
// addPlatformLog collects logs for the iOS debug bundle. iOS has no logcat or
// systemd journal, so we rely on file-based logs. addLogfile handles the Go
// client log (logPath) with rotation, the stderr/stdout companions and
// anonymization. The iOS app writes its own Swift log into the same directory,
// so we add it alongside the Go log.
func (g *BundleGenerator) addPlatformLog() error {
if err := g.addLogfile(); err != nil {
return err
}
if g.logPath == "" {
return nil
}
swiftLogPath := filepath.Join(filepath.Dir(g.logPath), swiftLogFile)
if err := g.addSingleLogfile(swiftLogPath, swiftLogFile); err != nil {
// The Swift log is best-effort: the app may not have written it yet.
log.Warnf("failed to add %s to debug bundle: %v", swiftLogFile, err)
}
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !android && !ios
//go:build !android
package debug

View File

@@ -843,7 +843,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
"Name": "non-config: profile name is not needed for debug purposes",
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
}

View File

@@ -53,6 +53,7 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/syncstore"
"github.com/netbirdio/netbird/client/internal/updater"
@@ -239,7 +240,7 @@ type Engine struct {
syncStore syncstore.Store
syncStoreDir string
flowManager nftypes.FlowManager
flowManager nftypes.FlowManager
// auto-update
updateManager *updater.Manager
@@ -530,10 +531,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("create wg interface: %w", err)
}
if filteredDevice := e.wgInterface.GetDevice(); filteredDevice != nil {
filteredDevice.SetPanicHandler(e.triggerClientRestart)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
@@ -1714,13 +1711,6 @@ func (e *Engine) receiveSignalEvents() {
return e.ctx.Err()
}
// Self-addressed heartbeat: the signal client's receive watchdog
// round-trips this through the server to confirm the receive stream
// is delivering. Liveness is already recorded before this handler.
if msg.GetBody().GetType() == sProto.Body_HEARTBEAT {
return nil
}
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
@@ -1919,6 +1909,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
WGPrivKey: e.config.WgPrivateKey.String(),
MTU: e.config.MTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
}
@@ -2166,6 +2157,21 @@ func (e *Engine) startNetworkMonitor() {
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.routeManager.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
if e.dnsServer == nil {
return

View File

@@ -1024,17 +1024,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
// extend the list of stun, turn servers with the relay server connections
// extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates)
states := d.relayMgr.RelayStates()
if len(states) == 0 {
// no relay connection tracked yet; surface configured servers as
// unavailable with the real reconnect error when known
err := relayClient.ErrRelayClientNotConnected
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
err = connErr
}
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
@@ -1044,14 +1041,10 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return relayStates
}
for _, rs := range states {
relayStates = append(relayStates, relay.ProbeResult{
URI: rs.URL,
Err: rs.Err,
Transport: rs.Transport,
})
relayState := relay.ProbeResult{
URI: instanceAddr,
}
return relayStates
return append(relayStates, relayState)
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
@@ -1412,7 +1405,6 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
Transport: relayState.Transport,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -164,6 +165,10 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
w.log.Errorf("error while handling remote candidate")
return
@@ -584,6 +589,34 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
return ec, nil
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
for _, prefix := range routePrefixes {
// default route is handled by route exclusion / ip rules
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}

View File

@@ -58,10 +58,6 @@ var DefaultInterfaceBlacklist = []string{
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}
// loadMDMPolicy is the package-level indirection used by apply() to read the
// active MDM policy. Tests override this to inject a fake policy.
var loadMDMPolicy = mdm.LoadPolicy
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
@@ -108,10 +104,6 @@ type ConfigInput struct {
// Config Configuration type
type Config struct {
// Name is the human-readable profile name shown in CLI/UI listings.
// It is independent of the profile's on-disk filename (which is the ID).
Name string
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string
@@ -184,14 +176,27 @@ type Config struct {
MTU uint16
// policy is the MDM policy that produced the currently-set values for
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
// and reset on every apply() invocation. Never persisted to disk.
// Callers query enforcement state via Policy() and the mdm.Policy API
// (HasKey, ManagedKeys, IsEmpty).
// policy is the MDM policy that produced the currently-set values
// for any MDM-enforced fields. Set by ApplyMDMPolicy on every
// invocation. Never persisted to disk. Callers query enforcement
// state via Policy() and the mdm.Policy API (HasKey, ManagedKeys,
// IsEmpty).
policy *mdm.Policy `json:"-"`
}
// ApplyMDMPolicy overlays the supplied MDM Policy on top of the
// currently resolved Config values. Idempotent — pass an empty Policy
// to clear any prior overlay. The lifecycle owner (Server.getConfig
// on desktop, the Client.Run path on mobile) calls this with
// loader.Load() once the per-process Loader is known; the Config
// itself holds no reference to the Loader.
func (config *Config) ApplyMDMPolicy(policy *mdm.Policy) {
if config == nil {
return
}
config.applyMDMPolicy(policy)
}
// Policy returns the MDM policy applied to this Config. Returns a non-nil
// empty Policy when MDM enforcement is inactive; callers can always invoke
// HasKey / ManagedKeys / IsEmpty without a nil check.
@@ -274,16 +279,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
}
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.Name != "" {
sanitized, err := sanitizeDisplayName(config.Name)
if err != nil {
return false, fmt.Errorf("invalid profile name: %w", err)
}
if sanitized != config.Name {
config.Name = sanitized
updated = true
}
}
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
@@ -648,9 +643,11 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
// MDM is the last override layer: any key present in the policy
// supersedes defaults, on-disk config, env vars and CLI input.
config.applyMDMPolicy(loadMDMPolicy())
// Initialise the MDM overlay to "no enforcement" so Config.Policy()
// never returns a stale or nil policy on a freshly applied Config.
// Lifecycle owners that want to enforce a real MDM policy invoke
// Config.ApplyMDMPolicy(loader.Load()) after this returns.
config.applyMDMPolicy(mdm.NewPolicy(nil))
return updated, nil
}

View File

@@ -10,24 +10,58 @@ import (
"github.com/netbirdio/netbird/client/mdm"
)
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
// apply() observes the supplied Policy. The original loader is restored at
// test cleanup.
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
// fakeFetcher implements mdm.PolicyFetcher returning a pre-set policy
// map. Test helper used to construct a Loader without touching the OS
// or any package-level state.
type fakeFetcher struct{ values map[string]any }
func (f *fakeFetcher) Fetch() map[string]any { return f.values }
// loaderFor builds an mdm.Loader whose loadPlatform returns the
// supplied Policy's underlying values.
func loaderFor(policy *mdm.Policy) *mdm.Loader {
if policy == nil || policy.IsEmpty() {
return mdm.NewLoader(&fakeFetcher{values: nil})
}
values := make(map[string]any)
for _, k := range policy.ManagedKeys() {
if v, ok := policy.GetString(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetBool(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetInt(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetStringSlice(k); ok {
values[k] = v
}
}
return mdm.NewLoader(&fakeFetcher{values: values})
}
// configWithMDM is the test convenience that builds a Config via
// UpdateOrCreateConfig and overlays the supplied MDM policy on top —
// mirrors the production pattern (Server.getConfig / Client.applyMDMOverlay)
// where the Loader lives outside Config and the apply step is driven
// by the lifecycle owner.
func configWithMDM(t *testing.T, input ConfigInput, policy *mdm.Policy) *Config {
t.Helper()
prev := loadMDMPolicy
loadMDMPolicy = func() *mdm.Policy { return policy }
t.Cleanup(func() { loadMDMPolicy = prev })
cfg, err := UpdateOrCreateConfig(input)
require.NoError(t, err)
require.NotNil(t, cfg)
cfg.ApplyMDMPolicy(loaderFor(policy).Load())
return cfg
}
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(nil))
cfg, err := UpdateOrCreateConfig(ConfigInput{
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
}, mdm.NewPolicy(nil))
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
@@ -39,18 +73,15 @@ func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
const mdmURL = "https://corp.mdm.example.com:443"
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
}, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: mdmURL,
mdm.KeyDisableClientRoutes: true,
mdm.KeyBlockInbound: true,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
assert.True(t, cfg.DisableClientRoutes)
assert.True(t, cfg.BlockInbound)
@@ -65,16 +96,12 @@ func TestApply_MDMBeatsCLIInput(t *testing.T) {
const mdmURL = "https://mdm.example.com:443"
const cliURL = "https://cli.example.com:443"
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: mdmURL,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
ManagementURL: cliURL,
})
require.NoError(t, err)
require.NotNil(t, cfg)
}, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: mdmURL,
}))
// MDM wins over CLI-supplied management URL.
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
@@ -82,16 +109,12 @@ func TestApply_MDMBeatsCLIInput(t *testing.T) {
}
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
}, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "not-a-url",
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
// Invalid MDM URL is logged and skipped: default URL stays in place
// to keep the client functional.
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
@@ -106,24 +129,20 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "config.json")
// Seed without MDM.
withMDMPolicy(t, mdm.NewPolicy(nil))
_, err := UpdateOrCreateConfig(ConfigInput{
configWithMDM(t, ConfigInput{
ConfigPath: tmp,
DisableClientRoutes: boolPtr(false),
RosenpassEnabled: boolPtr(false),
})
require.NoError(t, err)
}, mdm.NewPolicy(nil))
// Now enable MDM enforcement for these keys.
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
cfg := configWithMDM(t, ConfigInput{
ConfigPath: tmp,
}, mdm.NewPolicy(map[string]any{
mdm.KeyDisableClientRoutes: true,
mdm.KeyRosenpassEnabled: true,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
assert.True(t, cfg.RosenpassEnabled)
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
@@ -133,16 +152,12 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
const maskSentinel = "**********"
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
}, mdm.NewPolicy(map[string]any{
mdm.KeyPreSharedKey: maskSentinel,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
// Mask sentinel must not be persisted as the actual PSK.
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
// Key still marked managed so user writes are still rejected.

View File

@@ -1,118 +0,0 @@
package profilemanager
import (
"crypto/rand"
"encoding/hex"
"fmt"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
)
const (
// profileIDByteLen is the number of random bytes generated for a new
// profile ID. The resulting hex string is twice this length.
profileIDByteLen = 16
// shortIDLen is the number of leading characters of an ID we render in
// list output. Profiles per device are few, so 8 chars is collision-safe
// in practice and easy to type as a prefix.
shortIDLen = 8
// maxProfileNameLen caps the human-readable profile name to keep table
// output legible and prevent denial-of-service via huge JSON fields.
maxProfileNameLen = 128
// maxProfileIDLen bounds the on-disk filename we'll accept. New
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
// cap is generous enough to cover both without permitting absurdly
// long filenames.
maxProfileIDLen = 64
)
type ID string
// generateProfileID returns a new random hex ID for a profile file.
func generateProfileID() (ID, error) {
buf := make([]byte, profileIDByteLen)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("read random bytes: %w", err)
}
return ID(hex.EncodeToString(buf)), nil
}
// IsValidProfileFilenameStem reports whether id is safe to use as the stem
// of a profile JSON filename.
func IsValidProfileFilenameStem(id ID) bool {
s := id.String()
if s == "" || len(s) > maxProfileIDLen {
return false
}
if s == defaultProfileName {
return true
}
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
return false
}
// filepath.Base catches any leftover separators on platforms with
// exotic path conventions.
if filepath.Base(s) != s {
return false
}
for _, r := range s {
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
return false
}
}
return true
}
// sanitizeDisplayName normalizes a user-supplied profile display name for
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
// preserved. Returns an error if nothing usable remains.
func sanitizeDisplayName(name string) (string, error) {
if !utf8.ValidString(name) {
return "", fmt.Errorf("name is not valid UTF-8")
}
name = StripCtrlChars(name)
name = strings.TrimSpace(name)
if name == "" {
return "", fmt.Errorf("name is empty after sanitization")
}
if utf8.RuneCountInString(name) > maxProfileNameLen {
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
}
return name, nil
}
// StripCtrlChars control characters from a name before printing it.
func StripCtrlChars(name string) string {
var b strings.Builder
b.Grow(len(name))
for _, r := range name {
// Skip C0 controls and DEL, plus C1 controls (0x800x9F).
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
continue
}
b.WriteRune(r)
}
return b.String()
}
// ShortID truncates an ID for display.
func (id ID) ShortID() string {
if id == DefaultProfileName {
return DefaultProfileName
}
runes := []rune(id)
if len(runes) <= shortIDLen {
return id.String()
}
return string(runes[:shortIDLen])
}
func (id ID) String() string {
return string(id)
}

View File

@@ -19,41 +19,19 @@ const (
)
type Profile struct {
// ID is the on-disk filename stem (without .json). For new profiles
// it is a 32-char hex string; legacy profiles created before the
// ID-keyed layout keep their original name as their ID. The reserved
// value "default" identifies the special default profile.
ID ID
// Name is the human-readable display name. Falls back to ID when the
// underlying JSON has no "name" field set.
Name string
// Path is the absolute path to the profile JSON. Populated by the
// loader so callers do not have to reconstruct it from ID + dir.
Path string
Name string
IsActive bool
}
func (p *Profile) FilePath() (string, error) {
if p.Path != "" {
return p.Path, nil
if p.Name == "" {
return "", fmt.Errorf("active profile name is empty")
}
id := p.ID
if id == "" {
id = ID(p.Name)
}
if id == "" {
return "", fmt.Errorf("profile ID is empty")
}
if id == defaultProfileName {
if p.Name == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(id) {
return "", fmt.Errorf("invalid profile ID: %q", id)
}
username, err := user.Current()
if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err)
@@ -64,13 +42,10 @@ func (p *Profile) FilePath() (string, error) {
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
}
return filepath.Join(configDir, id.String()+".json"), nil
return filepath.Join(configDir, p.Name+".json"), nil
}
func (p *Profile) IsDefault() bool {
if p.ID != "" {
return p.ID == defaultProfileName
}
return p.Name == defaultProfileName
}
@@ -82,24 +57,18 @@ func NewProfileManager() *ProfileManager {
return &ProfileManager{}
}
// GetActiveProfile returns the active profile as recorded in the local
// user state file. Only ID is populated.
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
id := pm.getActiveProfileState()
return &Profile{ID: id}, nil
prof := pm.getActiveProfileState()
return &Profile{Name: prof}, nil
}
// SwitchProfile records the given profile ID as active in the local user
// state file.
func (pm *ProfileManager) SwitchProfile(id ID) error {
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
func (pm *ProfileManager) SwitchProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
if err := pm.setActiveProfileState(id); err != nil {
if err := pm.setActiveProfileState(profileName); err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
return nil
@@ -116,7 +85,7 @@ func sanitizeProfileName(name string) string {
}, name)
}
func (pm *ProfileManager) getActiveProfileState() ID {
func (pm *ProfileManager) getActiveProfileState() string {
configDir, err := getConfigDir()
if err != nil {
@@ -144,10 +113,10 @@ func (pm *ProfileManager) getActiveProfileState() ID {
return defaultProfileName
}
return ID(profileName)
return profileName
}
func (pm *ProfileManager) setActiveProfileState(id ID) error {
func (pm *ProfileManager) setActiveProfileState(profileName string) error {
configDir, err := getConfigDir()
if err != nil {
@@ -156,7 +125,7 @@ func (pm *ProfileManager) setActiveProfileState(id ID) error {
statePath := filepath.Join(configDir, activeProfileStateFilename)
err = os.WriteFile(statePath, []byte(id), 0600)
err = os.WriteFile(statePath, []byte(profileName), 0600)
if err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
@@ -173,7 +142,7 @@ func GetLoginHint() string {
return ""
}
profileState, err := pm.GetProfileState(activeProf.ID)
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""

View File

@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
state, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, defaultProfileName, state.ID.String()) // No active profile state yet
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
err = sm.SetActiveProfileStateToDefault()
assert.NoError(t, err)
active, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, "default", active.ID.String())
assert.Equal(t, "default", active.Name)
})
})
}
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
currUser, err := user.Current()
assert.NoError(t, err)
sm := &ServiceManager{}
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
err = sm.SetActiveProfileState(state)
assert.NoError(t, err)
// Should error on nil or incomplete state
err = sm.SetActiveProfileState(nil)
assert.Error(t, err)
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
assert.Error(t, err)
})
})

View File

@@ -2,7 +2,6 @@ package profilemanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -24,43 +23,12 @@ var (
DefaultConfigPathDir = ""
DefaultConfigPath = ""
ActiveProfileStatePath = ""
)
var (
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
)
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
// matches more than one profile. Callers can render Candidates to help the
// user disambiguate.
type ErrAmbiguousHandle struct {
Handle string
Candidates []Profile
Kind AmbiguityKind
}
// AmbiguityKind describes which matcher produced the ambiguity, so callers
// can tailor the error message.
type AmbiguityKind int
const (
AmbiguityKindIDPrefix AmbiguityKind = iota
AmbiguityKindName
)
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
// reading all fields
type profileMeta struct {
Name string
}
func (e *ErrAmbiguousHandle) Error() string {
switch e.Kind {
case AmbiguityKindIDPrefix:
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
default:
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
}
}
func init() {
DefaultConfigPathDir = "/var/lib/netbird/"
@@ -86,34 +54,25 @@ func init() {
}
type ActiveProfileState struct {
// ID is the on-disk filename stem of the active profile. The JSON tag stays
// as "name" for backwards compatibility with active state files written
// before the ID-based config files. Legacy values were profile names, which
// were also the legacy filename stems, so they still resolve to the correct
// file on disk.
ID ID `json:"name"`
Name string `json:"name"`
Username string `json:"username"`
}
func (a *ActiveProfileState) FilePath() (string, error) {
if a.ID == "" {
return "", fmt.Errorf("active profile ID is empty")
if a.Name == "" {
return "", fmt.Errorf("active profile name is empty")
}
if a.ID == defaultProfileName {
if a.Name == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(a.ID) {
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
}
configDir, err := getConfigDirForUser(a.Username)
if err != nil {
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
}
return filepath.Join(configDir, a.ID.String()+".json"), nil
return filepath.Join(configDir, a.Name+".json"), nil
}
type ServiceManager struct {
@@ -219,7 +178,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
ID: defaultProfileName,
Name: "default",
Username: "",
}, nil
} else {
@@ -227,12 +186,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
}
}
if activeProfile.ID == "" {
if activeProfile.Name == "" {
if err := s.SetActiveProfileStateToDefault(); err != nil {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
ID: defaultProfileName,
Name: "default",
Username: "",
}, nil
}
@@ -257,29 +216,25 @@ func (s *ServiceManager) setDefaultActiveState() error {
}
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
if a == nil || a.ID == "" {
if a == nil || a.Name == "" {
return errors.New("invalid active profile state")
}
if a.ID != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
}
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
return fmt.Errorf("invalid profile ID: %q", a.ID)
if a.Name != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
}
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
log.Infof("active profile set to %s for %s", a.ID, a.Username)
log.Infof("active profile set to %s for %s", a.Name, a.Username)
return nil
}
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
return s.SetActiveProfileState(&ActiveProfileState{
ID: defaultProfileName,
Name: "default",
Username: "",
})
}
@@ -288,117 +243,57 @@ func (s *ServiceManager) DefaultProfilePath() string {
return DefaultConfigPath
}
// AddProfile creates a new profile with a generated ID. The user-supplied
// displayName is stored inside the JSON's name field, the on-disk filename
// uses the generated ID.
//
// The returned Profile carries the freshly-generated ID so callers can
// show it to the user (and so the gRPC AddProfileResponse can include
// it).
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
func (s *ServiceManager) AddProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
return fmt.Errorf("failed to get config directory: %w", err)
}
displayName, err = sanitizeDisplayName(displayName)
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return nil, fmt.Errorf("invalid profile name: %w", err)
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists
}
id, err := generateProfileID()
if err != nil {
return nil, fmt.Errorf("generate profile id: %w", err)
}
profPath := filepath.Join(configDir, id.String()+".json")
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
if err != nil {
return nil, fmt.Errorf("failed to create new config: %w", err)
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
return nil, fmt.Errorf("failed to write profile config: %w", err)
return fmt.Errorf("failed to create new config: %w", err)
}
return &Profile{
ID: id,
Name: displayName,
Path: profPath,
}, nil
}
func (s *ServiceManager) RenameProfile(id ID, username string, newName string) error {
displayName, err := sanitizeDisplayName(newName)
err = util.WriteJson(context.Background(), profPath, cfg)
if err != nil {
return fmt.Errorf("invalid profile name: %w", err)
return fmt.Errorf("failed to write profile config: %w", err)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return fmt.Errorf("load profiles: %w", err)
}
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
return ErrProfileNotFound
}
data, err := os.ReadFile(target.Path)
if err != nil {
return err
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), target.Path, cfg); err != nil {
return fmt.Errorf("failed to write profile name: %w", err)
}
return nil
}
// RemoveProfile deletes the profile identified by id. Callers must have
// already resolved any user-supplied handle to a concrete ID via
// ResolveProfile.
func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if id == defaultProfileName {
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
return fmt.Errorf("cannot remove default profile with name: %s", defaultName)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profiles, err := s.loadAllProfiles(username)
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("load profiles: %w", err)
return fmt.Errorf("failed to get config directory: %w", err)
}
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
if target == nil {
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
return ErrProfileNotFound
}
@@ -406,26 +301,57 @@ func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
return fmt.Errorf("failed to get active profile: %w", err)
}
if activeProf != nil && activeProf.ID == id {
return fmt.Errorf("cannot remove active profile: %s", id)
if activeProf != nil && activeProf.Name == profileName {
return fmt.Errorf("cannot remove active profile: %s", profileName)
}
if err := util.RemoveJson(target.Path); err != nil {
err = util.RemoveJson(profPath)
if err != nil {
return fmt.Errorf("failed to remove profile config: %w", err)
}
stateFile := filepath.Join(filepath.Dir(target.Path), id.String()+".state.json")
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
}
return nil
}
// ListProfiles returns every profile for the given user, including the
// default profile, with IsActive flags set.
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
return s.loadAllProfiles(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
files, err := util.ListFiles(configDir, "*.json")
if err != nil {
return nil, fmt.Errorf("failed to list profile files: %w", err)
}
var filtered []string
for _, file := range files {
if strings.HasSuffix(file, "state.json") {
continue // skip state files
}
filtered = append(filtered, file)
}
sort.Strings(filtered)
var activeProfName string
activeProf, err := s.GetActiveProfileState()
if err == nil {
activeProfName = activeProf.Name
}
var profiles []Profile
// add default profile always
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
for _, file := range filtered {
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
var isActive bool
if activeProfName != "" && activeProfName == profileName {
isActive = true
}
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
}
return profiles, nil
}
// GetStatePath returns the path to the state file based on the operating system
@@ -443,12 +369,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
if activeProf.ID == defaultProfileName {
return defaultStatePath
}
if !IsValidProfileFilenameStem(activeProf.ID) {
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
if activeProf.Name == defaultProfileName {
return defaultStatePath
}
@@ -458,7 +379,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
return filepath.Join(configDir, activeProf.ID.String()+".state.json")
return filepath.Join(configDir, activeProf.Name+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
@@ -469,169 +390,3 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
return getConfigDirForUser(username)
}
// loadAllProfiles returns every profile visible to the daemon for the
// given user, including the default profile. The returned slice is sorted
// by ID for a stable display order.
//
// Each Profile is fully populated: ID is the filename stem, Name comes
// from the JSON's "name" field (falling back to the filename stem when absent)
// and Path is built from a basename read off disk.
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
activeID, activeIsDefault := s.activeProfileID()
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
profiles := []Profile{{
ID: defaultProfileName,
Name: defaultName,
Path: DefaultConfigPath,
IsActive: activeIsDefault,
}}
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
entries, err := os.ReadDir(configDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return profiles, nil
}
return nil, fmt.Errorf("read profile directory: %w", err)
}
var fileProfiles []Profile
for _, entry := range entries {
if entry.IsDir() {
continue
}
base := entry.Name()
if !strings.HasSuffix(base, ".json") {
continue
}
if strings.HasSuffix(base, ".state.json") {
continue
}
stem := ID(strings.TrimSuffix(base, ".json"))
if stem == defaultProfileName {
// default lives at the top-level config dir, not under /<user>
continue
}
if !IsValidProfileFilenameStem(ID(stem)) {
continue
}
path := filepath.Join(configDir, base)
name := readProfileName(path)
if name == "" {
name = stem.String()
}
fileProfiles = append(fileProfiles, Profile{
ID: stem,
Name: name,
Path: path,
IsActive: stem == ID(activeID),
})
}
sort.Slice(fileProfiles, func(i, j int) bool {
if fileProfiles[i].Name != fileProfiles[j].Name {
return fileProfiles[i].Name < fileProfiles[j].Name
}
// Sort tie-break on ID so duplicate names always render in the same order.
return fileProfiles[i].ID < fileProfiles[j].ID
})
profiles = append(profiles, fileProfiles...)
return profiles, nil
}
// readProfileName parses just the "name" field from the profile Json.
func readProfileName(path string) string {
data, err := os.ReadFile(path)
if err != nil {
return ""
}
var meta profileMeta
if err := json.Unmarshal(data, &meta); err != nil {
return ""
}
return meta.Name
}
// activeProfileID returns the currently-active profile's ID. The second
// return value is true when the active profile is the default one.
func (s *ServiceManager) activeProfileID() (ID, bool) {
state, err := s.GetActiveProfileState()
if err != nil || state == nil {
return defaultProfileName, true
}
if state.ID == "" || state.ID == defaultProfileName {
return defaultProfileName, true
}
return state.ID, false
}
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
// precedence is: exact ID match, then unique exact name, then unique ID
// prefix. Ambiguous matches return *ErrAmbiguousHandle so callers can
// surface the candidates.
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
if handle == "" {
return nil, fmt.Errorf("profile handle is empty")
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return nil, err
}
for i := range profiles {
if profiles[i].ID == ID(handle) {
return &profiles[i], nil
}
}
var nameMatches []Profile
for i := range profiles {
if profiles[i].Name == handle {
nameMatches = append(nameMatches, profiles[i])
}
}
if len(nameMatches) == 1 {
return &nameMatches[0], nil
}
if len(nameMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: nameMatches,
Kind: AmbiguityKindName,
}
}
// ID prefix match. Skip the default profile so `select d` does not
// accidentally pick it via prefix.
var prefixMatches []Profile
for i := range profiles {
if profiles[i].ID == defaultProfileName {
continue
}
if strings.HasPrefix(profiles[i].ID.String(), handle) {
prefixMatches = append(prefixMatches, profiles[i])
}
}
if len(prefixMatches) == 1 {
return &prefixMatches[0], nil
}
if len(prefixMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: prefixMatches,
Kind: AmbiguityKindIDPrefix,
}
}
return nil, ErrProfileNotFound
}

View File

@@ -1,230 +0,0 @@
package profilemanager
import (
"context"
"errors"
"os"
"os/user"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util"
)
// withTestSM wires up patched globals + a clean config dir and returns a
// fully initialized ServiceManager plus the username we are scoped to.
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
t.Helper()
withTempConfigDir(t, func(configDir string) {
withPatchedGlobals(t, configDir, func() {
u, err := user.Current()
require.NoError(t, err)
sm := &ServiceManager{}
require.NoError(t, sm.CreateDefaultProfile())
fn(sm, u.Username)
})
})
}
func TestServiceProfile_ExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile(created.ID.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_IDPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
prefix := created.ID[:4]
got, err := sm.ResolveProfile(prefix.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
})
}
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
// Plant two profiles whose IDs share a known prefix by writing
// the files directly, since generated IDs are random.
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
path := filepath.Join(configDir, id+".json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
}
_, err = sm.ResolveProfile("abcd", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_ExactNameUnique(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile("work", username)
require.NoError(t, err)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_AmbiguousName(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.ResolveProfile("work", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindName, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_NotFound(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.ResolveProfile("nope", username)
assert.ErrorIs(t, err, ErrProfileNotFound)
})
}
func TestServiceProfile_DefaultByExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
got, err := sm.ResolveProfile(defaultProfileName, username)
require.NoError(t, err)
assert.Equal(t, defaultProfileName, got.ID.String())
})
}
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
// Legacy profiles stored as <name>.json with no "name" JSON field
// should still be discoverable by name and removable by name.
withTestSM(t, func(sm *ServiceManager, username string) {
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
path := filepath.Join(configDir, "legacy.json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
got, err := sm.ResolveProfile("legacy", username)
require.NoError(t, err)
assert.Equal(t, "legacy", got.ID.String())
// Name falls back to the filename stem when JSON omits it.
assert.Equal(t, "legacy", got.Name)
})
}
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
first, err := sm.AddProfile("work", username)
require.NoError(t, err)
second, err := sm.AddProfile("work", username)
require.NoError(t, err)
assert.NotEqual(t, first.ID, second.ID)
assert.Equal(t, "work", second.Name)
})
}
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
cases := []string{
"", // empty
"\x00\x01", // only control chars (becomes empty)
strings.Repeat("a", maxProfileNameLen+1), // too long
}
for _, name := range cases {
_, err := sm.AddProfile(name, username)
assert.Error(t, err, "expected error for %q", name)
}
})
}
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
err := sm.RemoveProfile("../escape", username)
assert.Error(t, err)
})
}
func TestSanitizeDisplayName(t *testing.T) {
cases := []struct {
in string
want string
wantErr bool
}{
{"work", "work", false},
{"My Work Account", "My Work Account", false},
{"emoji 🚀 ok", "emoji 🚀 ok", false},
{"漢字テスト", "漢字テスト", false},
{"with\x00null", "withnull", false},
{"\x01\x02\x03", "", true},
{"", "", true},
}
for _, tc := range cases {
got, err := sanitizeDisplayName(tc.in)
if tc.wantErr {
assert.Error(t, err, "case %q", tc.in)
continue
}
assert.NoError(t, err, "case %q", tc.in)
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestIsValidProfileFilenameStem(t *testing.T) {
cases := []struct {
in string
want bool
}{
{"default", true},
{"abc123def456", true},
{"legacy-name", true},
{"legacy_name", true},
{"", false},
{"..", false},
{"../etc", false},
{"foo/bar", false},
{`foo\bar`, false},
{"with space", false},
{"with.dot", false},
{strings.Repeat("a", maxProfileIDLen+1), false},
}
for _, tc := range cases {
got := IsValidProfileFilenameStem(ID(tc.in))
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
statePath := filepath.Join(configDir, created.ID.String()+".state.json")
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
require.NoError(t, sm.RemoveProfile(created.ID, username))
_, err = os.Stat(statePath)
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
})
}

View File

@@ -13,20 +13,13 @@ type ProfileState struct {
Email string `json:"email"`
}
// GetProfileState reads the per-profile state file keyed by profile ID.
// The state file lives in the user's config directory. Legacy state files
// keyed by the old profile name remain readable.
func (pm *ProfileManager) GetProfileState(id ID) (*ProfileState, error) {
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
configDir, err := getConfigDir()
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return nil, fmt.Errorf("invalid profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
stateFile := filepath.Join(configDir, profileName+".state.json")
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
@@ -58,12 +51,7 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
return fmt.Errorf("get active profile: %w", err)
}
id := activeProf.ID
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid active profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
if err != nil {
return fmt.Errorf("write profile state: %w", err)

View File

@@ -32,9 +32,6 @@ type ProbeResult struct {
URI string
Err error
Addr string
// Transport is the negotiated relay transport, empty
// for stun/turn probes or when not connected.
Transport string
}
type StunTurnProbe struct {

View File

@@ -22,14 +22,14 @@ type removePeerCall struct {
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
@@ -51,7 +51,7 @@ func (m *mockServer) RemovePeer(id rp.PeerID) error {
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {

View File

@@ -41,3 +41,4 @@ func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -9,7 +9,6 @@ import (
"net/url"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -333,8 +332,6 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
}
m.notifier.Close()
m.mux.Lock()
defer m.mux.Unlock()
m.clientRoutes = nil
@@ -703,8 +700,6 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
m.mirrorV6ExitPairSelections(clientRoutes)
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
@@ -721,24 +716,6 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
m.logExitNodeUpdate(exitNodeInfo)
}
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
// consistent with its v4 base. The v4/v6 exit pair is a single toggle, so the v6
// entry always follows the base: deselecting the v4 exit node also drops its ::/0
// pair, and any stale (orphaned) explicit selection on the v6 entry is reset. This
// runs before selection is read so both collectExitNodeInfo and FilterSelectedExitNodes
// see consistent state, including pairs loaded from persisted selector state.
func (m *DefaultManager) mirrorV6ExitPairSelections(clientRoutes route.HAMap) {
routesByNetID := make(map[route.NetID][]*route.Route, len(clientRoutes))
for haID, routes := range clientRoutes {
routesByNetID[haID.NetID()] = routes
}
for v6ID := range route.V6ExitMergeSet(routesByNetID) {
baseID := route.NetID(strings.TrimSuffix(string(v6ID), route.V6ExitSuffix))
m.routeSelector.SyncPairedSelection(baseID, v6ID)
}
}
type exitNodeInfo struct {
allIDs []route.NetID
selectedByManagement []route.NetID

View File

@@ -1,47 +0,0 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
// TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair reproduces the bug seen
// in netbird-engine.log: persisted selector state has the v4 exit node deselected
// but its synthesized "-v6" pair explicitly selected (orphaned), so the ::/0 route
// leaked onto the tunnel. The management update must mirror the v4 deselect onto the
// v6 pair so FilterSelectedExitNodes drops it.
func TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair(t *testing.T) {
const (
v4ID = route.NetID("Exit Node (raspberrypi)")
v6ID = route.NetID("Exit Node (raspberrypi)-v6")
)
all := []route.NetID{v4ID, v6ID}
rs := routeselector.NewRouteSelector()
// Orphan the v6 selection: select the pair, then deselect only the v4 base.
require.NoError(t, rs.SelectRoutes([]route.NetID{v4ID, v6ID}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{v4ID}, all))
require.True(t, rs.IsSelected(v6ID), "precondition: orphaned v6 selection survives v4 deselect")
m := &DefaultManager{routeSelector: rs}
v4Route := &route.Route{NetID: v4ID, Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: v6ID, Network: netip.MustParsePrefix("::/0")}
clientRoutes := route.HAMap{
"Exit Node (raspberrypi)|0.0.0.0/0": {v4Route},
"Exit Node (raspberrypi)-v6|::/0": {v6Route},
}
m.updateRouteSelectorFromManagement(clientRoutes)
assert.False(t, rs.IsSelected(v6ID), "v6 pair must follow the v4 base deselect after the management update")
filtered := rs.FilterSelectedExitNodes(clientRoutes)
assert.Empty(t, filtered, "deselected v4 exit node must not leak its ::/0 pair onto the tunnel")
}

View File

@@ -16,7 +16,7 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoutes []*route.Route
fakeIPRoutes []*route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -119,7 +119,3 @@ func (n *Notifier) GetInitialRouteRanges() []string {
sort.Strings(initialStrings)
return initialStrings
}
func (n *Notifier) Close() {
// unused
}

View File

@@ -3,7 +3,6 @@
package notifier
import (
"container/list"
"net/netip"
"slices"
"sort"
@@ -15,26 +14,19 @@ import (
)
type Notifier struct {
mu sync.Mutex
cond *sync.Cond
currentPrefixes []string
listener listener.NetworkChangeListener
queue *list.List
closed bool
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
n := &Notifier{
queue: list.New(),
}
n.cond = sync.NewCond(&n.mu)
go n.deliverLoop()
return n
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.mu.Lock()
defer n.mu.Unlock()
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
@@ -51,52 +43,32 @@ func (n *Notifier) OnNewRoutes(route.HAMap) {
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0, len(prefixes))
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
n.mu.Lock()
if slices.Equal(n.currentPrefixes, newNets) {
n.mu.Unlock()
return
}
n.currentPrefixes = newNets
routes := strings.Join(n.currentPrefixes, ",")
n.queue.PushBack(routes)
n.cond.Signal()
n.mu.Unlock()
}
func (n *Notifier) Close() {
n.mu.Lock()
n.closed = true
n.cond.Signal()
n.mu.Unlock()
n.currentPrefixes = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.currentPrefixes, ","))
}(n.listener)
}
func (n *Notifier) GetInitialRouteRanges() []string {
return nil
}
func (n *Notifier) deliverLoop() {
for {
n.mu.Lock()
for n.queue.Len() == 0 && !n.closed {
n.cond.Wait()
}
if n.closed && n.queue.Len() == 0 {
n.mu.Unlock()
return
}
routes := n.queue.Remove(n.queue.Front()).(string)
l := n.listener
n.mu.Unlock()
if l != nil {
l.OnNetworkChanged(routes)
}
}
}

View File

@@ -38,7 +38,3 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
func (n *Notifier) GetInitialRouteRanges() []string {
return []string{}
}
func (n *Notifier) Close() {
// unused
}

View File

@@ -121,12 +121,9 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, vars.ErrRouteNotAllowed
}
// BSDs blackhole a /32 added inside a directly-connected subnet; Linux/Windows need it to beat the wt0 route.
switch runtime.GOOS {
case "darwin", "freebsd", "netbsd", "openbsd", "dragonfly":
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -131,33 +132,6 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
return rs.isSelectedLocked(routeID)
}
// SyncPairedSelection forces pairedID's explicit selection state to match baseID's,
// so a synthesized "-v6" exit route always follows its v4 base: selecting or
// deselecting the v4 exit node governs the ::/0 pair, and any stale (orphaned)
// explicit state on the v6 entry is reset. The v4/v6 exit pair is treated as a single
// toggle, so the v6 entry carries no independent selection of its own.
func (rs *RouteSelector) SyncPairedSelection(baseID, pairedID route.NetID) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.deselectAll {
return
}
_, baseSelected := rs.selectedRoutes[baseID]
_, baseDeselected := rs.deselectedRoutes[baseID]
delete(rs.selectedRoutes, pairedID)
delete(rs.deselectedRoutes, pairedID)
switch {
case baseSelected:
rs.selectedRoutes[pairedID] = struct{}{}
case baseDeselected:
rs.deselectedRoutes[pairedID] = struct{}{}
}
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
@@ -177,13 +151,14 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// The lookup is literal; v4/v6 exit pairs are kept consistent at write time via SyncPairedSelection,
// so a synthesized "-v6" entry carries the same explicit state as its v4 base.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.hasUserSelectionForRouteLocked(routeID)
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -212,6 +187,83 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
@@ -265,59 +317,3 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
return nil
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelectionForRouteLocked(netID) {
if rs.isSelectedLocked(netID) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}

View File

@@ -330,73 +330,39 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairSync covers SyncPairedSelection, which keeps a v4
// exit node and its synthesized "-v6" counterpart consistent. The selector itself
// is literal and never infers a v6 entry's state from its v4 base; callers that know
// the pairing (exit-node code paths) call SyncPairedSelection to force the v6 entry
// to follow the base, treating the pair as a single toggle.
func TestRouteSelector_V6ExitPairSync(t *testing.T) {
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("selector lookups stay literal without sync", func(t *testing.T) {
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// The selector does not pair-resolve: the v6 entry is independent until synced.
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 entry has no state of its own")
assert.True(t, rs.IsSelected("exit1-v6"), "unsynced v6 entry stays selected by default")
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// A route literally named "exit1-something" must never pair-resolve either.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
})
t.Run("sync mirrors deselected v4 base onto v6", func(t *testing.T) {
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1"))
assert.False(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base deselect")
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 carries explicit deselect after sync")
})
t.Run("sync mirrors selected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1"}, false, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.True(t, rs.IsSelected("exit1"))
assert.True(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base select")
})
t.Run("sync clears v6 state when base has no explicit selection", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
require.True(t, rs.HasUserSelectionForRoute("exit1-v6"))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"),
"v6 explicit state is cleared so it follows management like its base")
})
// Regression for the observed bug (see netbird-engine.log): persisted state has
// the v4 base deselected but the v6 sibling explicitly selected (orphaned). The
// sync must reset the orphan so the ::/0 route does not leak onto the tunnel.
t.Run("sync clears orphaned explicit v6 selection on deselected base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Prior state: both explicitly selected, then only the v4 base deselected,
// leaving the v6 entry as a stale explicit selection.
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1", "exit1-v6"}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.True(t, rs.IsSelected("exit1-v6"), "precondition: orphaned v6 selection")
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1-v6"), "orphaned v6 selection reset to follow v4 deselect")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -404,14 +370,23 @@ func TestRouteSelector_V6ExitPairSync(t *testing.T) {
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "deselecting v4 base must drop the v6 pair even if it was explicitly selected before")
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
})
t.Run("filter drops synced v6 pair of deselected v4 base", func(t *testing.T) {
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -424,15 +399,6 @@ func TestRouteSelector_V6ExitPairSync(t *testing.T) {
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("deselectAll makes sync a no-op", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
rs.DeselectAllRoutes()
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "sync must not write explicit state under deselectAll")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))

View File

@@ -17,7 +17,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -26,7 +25,6 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -56,7 +54,6 @@ type selectRoute struct {
Network netip.Prefix
Domains domain.List
Selected bool
Status string
extraNetworks []netip.Prefix
}
@@ -68,8 +65,6 @@ func init() {
type Client struct {
cfgFile string
stateFile string
cacheDir string
logFilePath string
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
@@ -80,21 +75,16 @@ type Client struct {
onHostDnsFn func([]string)
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
}
// NewClient instantiate a new Client
func NewClient(cfgFile, stateFile, cacheDir, logFilePath, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
return &Client{
cfgFile: cfgFile,
stateFile: stateFile,
cacheDir: cacheDir,
logFilePath: logFilePath,
deviceName: deviceName,
osName: osName,
osVersion: osVersion,
@@ -171,13 +161,8 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, connectClient)
// Persist the latest sync response so DebugBundle can include the network
// map. On iOS this is backed by disk to keep it out of the constrained
// process memory (see the syncstore package).
connectClient.SetSyncResponsePersistence(true)
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources
@@ -189,84 +174,6 @@ func (c *Client) Stop() {
}
c.ctxCancel()
c.setState(nil, nil)
}
// DebugBundle generates a debug bundle, uploads it and returns the upload key.
// It works with or without a running engine: when the engine is up it reuses
// the live config, sync response and client metrics; otherwise it loads the
// config from disk (or the preloaded tvOS config).
func (c *Client) DebugBundle(anonymize bool) (string, error) {
cfg, cc := c.stateSnapshot()
// If the engine hasn't been started, load config so we can reach management.
if cfg == nil {
if c.preloadedConfig != nil {
cfg = c.preloadedConfig
} else {
var err error
// Use DirectUpdateOrCreateConfig to avoid atomic file operations
// (temp file + rename) blocked by the tvOS sandbox.
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
}
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: c.cacheDir,
StatePath: c.stateFile,
LogPath: c.logFilePath,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
@@ -320,16 +227,6 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
// IsLoginRequiredCached reports whether the LAST observed management error was an
// auth failure (PermissionDenied/InvalidArgument), using the in-memory status
// recorder. Unlike IsLoginRequired() it performs NO network call, so it is safe to
// call from the connection listener during teardown (e.g. onDisconnected) without
// blocking on a slow or unavailable network. Returns false while connected to
// management or when the last error was not auth-related.
func (c *Client) IsLoginRequiredCached() bool {
return c.recorder.IsLoginRequired()
}
func (c *Client) IsLoginRequired() bool {
var ctx context.Context
//nolint
@@ -457,12 +354,11 @@ func (c *Client) ClearLoginComplete() {
}
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
if c.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := connectClient.Engine()
engine := c.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
}
@@ -481,57 +377,9 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
// Compute each route's connection status in the core (mirroring the Android
// bridge), so the UI doesn't have to infer it by string-matching the joined
// Network value against peer routes. For a merged exit node the status reflects
// whichever of the v4/v6 prefixes is served by a connected peer; for dynamic
// (DNS) routes the peer route key is the domain pattern (see dynamic.Route.String).
connectedRoutes := c.connectedRouteSet()
for _, r := range routes {
r.Status = routeStatus(r, connectedRoutes)
}
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
// connectedRouteSet returns the set of route keys (as strings) currently served by a
// connected peer, gathered across all connected peers' route tables. The keys match
// what the route manager records: a prefix string for static routes (e.g. "0.0.0.0/0")
// and the domain pattern for dynamic routes (e.g. "*.example.com").
func (c *Client) connectedRouteSet() map[string]struct{} {
connected := map[string]struct{}{}
for _, p := range c.recorder.GetFullStatus().Peers {
if p.ConnStatus != peer.StatusConnected {
continue
}
for r := range p.GetRoutes() {
connected[r] = struct{}{}
}
}
return connected
}
// routeStatus reports "Connected" if any of the route's keys is served by a connected
// peer: the primary Network prefix, an extra v6 network of a merged exit node, or the
// domain pattern for a dynamic DNS route. Otherwise "Idle".
func routeStatus(r *selectRoute, connectedRoutes map[string]struct{}) string {
keys := make([]string, 0, 1+len(r.extraNetworks))
if len(r.Domains) > 0 {
keys = append(keys, r.Domains.SafeString())
} else {
keys = append(keys, r.Network.String())
}
for _, extra := range r.extraNetworks {
keys = append(keys, extra.String())
}
for _, k := range keys {
if _, ok := connectedRoutes[k]; ok {
return peer.StatusConnected.String()
}
}
return peer.StatusIdle.String()
}
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
var routes []*selectRoute
for id, rt := range routesMap {
@@ -614,7 +462,6 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
Network: netStr,
Domains: &domainDetails,
Selected: r.Selected,
Status: r.Status,
})
}
@@ -623,12 +470,11 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
}
func (c *Client) SelectRoute(id string) error {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
if c.connectClient == nil {
return fmt.Errorf("not connected")
}
engine := connectClient.Engine()
engine := c.connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -654,11 +500,10 @@ func (c *Client) SelectRoute(id string) error {
}
func (c *Client) DeselectRoute(id string) error {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
if c.connectClient == nil {
return fmt.Errorf("not connected")
}
engine := connectClient.Engine()
engine := c.connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -682,22 +527,6 @@ func (c *Client) DeselectRoute(id string) error {
return nil
}
// setState stores the running engine state so DebugBundle can reuse the live
// config and ConnectClient. It is cleared on Stop.
func (c *Client) setState(cfg *profilemanager.Config, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.connectClient = cc
}
// stateSnapshot returns the current config and ConnectClient under the lock.
func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.connectClient
}
func formatDuration(d time.Duration) string {
ds := d.String()
dotIndex := strings.Index(ds, ".")

View File

@@ -20,7 +20,6 @@ type RoutesSelectionInfo struct {
Network string
Domains *DomainDetails
Selected bool
Status string
}
type DomainCollection interface {

View File

@@ -84,16 +84,48 @@ func NewPolicy(values map[string]any) *Policy {
return &Policy{values: values}
}
// LoadPolicy reads the platform-native MDM configuration. Returns an
// empty (but non-nil) Policy when no source is present, the source is
// empty, or the platform is unsupported.
// PolicyFetcher is implemented by mobile platforms (Android / iOS) that
// push the OS-managed configuration into the Go runtime instead of
// having Go read an on-disk source directly. Desktop platforms ignore
// this interface — Loader.loadPlatform on windows/darwin reads the
// registry / plist on its own. A Loader constructed with a non-nil
// fetcher delegates to it on mobile; passing nil disables MDM
// enforcement (loadPlatform returns nil values).
type PolicyFetcher interface {
Fetch() map[string]any
}
// Loader is the DI-friendly entry point for reading the active MDM
// policy. Construct one at the daemon's lifecycle owner (Server on
// desktop, gomobile-exposed bridge on mobile) and pass it to anything
// that needs to read MDM state (the reload ticker, profilemanager's
// Config). Each callsite has the Loader handed in instead of looking
// up package-level state.
type Loader struct {
fetcher PolicyFetcher
}
// NewLoader constructs a Loader. The fetcher is consulted only on
// mobile builds (ios || android); on desktop it is unused but accepted
// to keep a single constructor signature across platforms — pass nil
// on desktop.
func NewLoader(f PolicyFetcher) *Loader {
return &Loader{fetcher: f}
}
// Load reads the platform-native MDM configuration and returns a
// Policy. Returns an empty (but non-nil) Policy when no source is
// present, the source is empty, or the platform is unsupported.
//
// Diagnostic logging differentiates the three states:
// - source absent / unsupported platform: trace log only
// - source present, zero keys: info "MDM enrolled (no managed keys)"
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
func LoadPolicy() *Policy {
values, err := loadPlatformPolicy()
func (l *Loader) Load() *Policy {
if l == nil {
return &Policy{values: map[string]any{}}
}
values, err := l.loadPlatform()
if err != nil {
log.Tracef("MDM policy load: %v", err)
return &Policy{values: map[string]any{}}

View File

@@ -25,8 +25,10 @@ import (
// writable plist, as a defense against tampered installs.
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
// managed-preferences plist at policyPlistPath. Returns:
// loadPlatform reads the MDM-managed configuration from the macOS
// managed-preferences plist at policyPlistPath. The Loader's fetcher
// field is unused on this platform — the plist is the authoritative
// source. Returns:
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
// NetBird, or admin has not yet pushed a payload)
// - (map, nil) with N entries when N managed values are present
@@ -39,7 +41,13 @@ const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
// skipped so a stray entry in the payload does not block startup.
// Native plist value types map naturally onto the Policy accessor
// expectations (GetString / GetBool / GetInt / GetStringSlice).
func loadPlatformPolicy() (map[string]any, error) {
func (l *Loader) loadPlatform() (map[string]any, error) {
// Honour the injected fetcher when present so tests (and any
// future non-macOS MDM channel) can short-circuit the plist read
// with a scripted policy.
if l != nil && l.fetcher != nil {
return l.fetcher.Fetch(), nil
}
f, err := os.Open(policyPlistPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {

View File

@@ -2,13 +2,14 @@
package mdm
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
// resulting dictionary in-process via a gomobile entry point that lands in
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
// builds and returns (nil, nil) — the platform-absent sentinel that
// LoadPolicy in policy.go treats as "no MDM source present".
func loadPlatformPolicy() (map[string]any, error) {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
return nil, nil
// loadPlatform reads the OS-managed configuration via the native
// PolicyFetcher injected at Loader construction. Returns
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
// "no MDM source present" — when no fetcher was provided.
func (l *Loader) loadPlatform() (map[string]any, error) {
if l == nil || l.fetcher == nil {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
return nil, nil
}
return l.fetcher.Fetch(), nil
}

View File

@@ -2,13 +2,17 @@
package mdm
// loadPlatformPolicy returns no policy on platforms without an MDM channel
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
// the feature did not exist. Returns (nil, nil) — the platform-absent
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
// source present"; an error here would just translate to the same
// outcome with an extra log line.
func loadPlatformPolicy() (map[string]any, error) {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
// loadPlatform reads the MDM policy on platforms without a native MDM
// channel (Linux, FreeBSD). When no fetcher was injected the policy is
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
// "MDM enforcement disabled". A non-nil fetcher takes precedence: it
// is the test-seam used by unit tests to inject a scripted policy
// without touching the OS, and the same hook supports any future
// non-mobile OS that grows an out-of-band MDM channel.
func (l *Loader) loadPlatform() (map[string]any, error) {
if l != nil && l.fetcher != nil {
return l.fetcher.Fetch(), nil
}
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
return nil, nil
}

View File

@@ -150,10 +150,12 @@ func TestPolicy_GetStringSlice(t *testing.T) {
})
}
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
// degrade gracefully and never return nil.
p := LoadPolicy()
func TestLoader_NilFetcherReturnsEmpty(t *testing.T) {
// Loader.Load with no fetcher (desktop construction) must degrade
// gracefully and never return nil; on linux loadPlatform is a stub
// returning (nil, nil), and Load is expected to translate that
// into a non-nil empty Policy.
p := NewLoader(nil).Load()
require.NotNil(t, p)
assert.True(t, p.IsEmpty())
assert.Empty(t, p.ManagedKeys())

View File

@@ -61,8 +61,10 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
}
}
// loadPlatformPolicy reads the MDM-managed configuration from the
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
// loadPlatform reads the MDM-managed configuration from the Windows
// registry under HKLM\Software\Policies\NetBird. The Loader's fetcher
// field is unused on this platform — the registry is the
// authoritative source. Returns:
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
// - (map, nil) with N entries when N managed values are set (N may be 0)
// - (nil, err) on open / enumerate registry errors
@@ -70,7 +72,13 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
// Per-value type coercion + skip-on-error is delegated to
// readRegistryValue. Unknown value names are logged and skipped so a
// malformed deployment does not block startup.
func loadPlatformPolicy() (map[string]any, error) {
func (l *Loader) loadPlatform() (map[string]any, error) {
// Honour the injected fetcher when present so tests (and any
// future non-Windows MDM channel) can short-circuit the registry
// read with a scripted policy.
if l != nil && l.fetcher != nil {
return l.fetcher.Fetch(), nil
}
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
if err != nil {
if errors.Is(err, registry.ErrNotExist) {

View File

@@ -15,33 +15,33 @@ import (
// instead, hence anticipating the ticker mechanism entirely.
const DefaultReloadInterval = 1 * time.Minute
// policyLoader is the indirection through which the ticker reads the
// OS-native policy, both for the initial observation and on every tick.
// Production points it at LoadPolicy; tests in this package override it to
// feed a scripted sequence of policies without touching the real OS store.
var policyLoader = LoadPolicy
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
// invokes the onChange callback (supplied to Run) whenever the observed
// Policy diverges from the last observation (added / removed / changed
// keys). Launch with Run from a goroutine; cancel the supplied context
// to stop.
// Ticker periodically re-reads the OS-native MDM policy via the
// injected Loader and invokes the onChange callback (supplied to Run)
// whenever the observed Policy diverges from the last observation
// (added / removed / changed keys). Launch with Run from a goroutine;
// cancel the supplied context to stop.
type Ticker struct {
interval time.Duration
loader *Loader
prev *Policy
}
// NewTicker constructs a Ticker that will re-read the OS-native policy
// every reloadInterval once Run is called.
// The initial snapshot is populated by calling policyLoader at
// every reloadInterval once Run is called. The Loader is injected so
// the ticker doesn't depend on any package-level state — production
// passes the daemon-owned Loader, tests pass a fake Loader (built with
// a fake PolicyFetcher).
//
// The initial snapshot is populated by calling loader.Load() at
// construction time so the first tick only fires
// onChange when the policy actually changed since boot — without
// this baseline the first tick would report every currently-managed
// key as "added" and trigger a spurious engine restart.
func NewTicker(reloadInterval time.Duration) *Ticker {
func NewTicker(reloadInterval time.Duration, loader *Loader) *Ticker {
return &Ticker{
interval: reloadInterval,
prev: policyLoader(),
loader: loader,
prev: loader.Load(),
}
}
@@ -58,7 +58,7 @@ func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) erro
log.Info("MDM policy reload ticker stopped")
return
case <-tk.C:
curr := policyLoader()
curr := t.loader.Load()
if policiesEqual(t.prev, curr) {
continue
}

View File

@@ -13,28 +13,40 @@ import (
// testReloadInterval for speeding up the ticker cadence under `go test`
const testReloadInterval = 1 * time.Second
// withPolicyLoader overrides the package-level policyLoader for the duration
// of the test so the ticker observes a scripted policy instead of the real
// OS-native store. The original loader is restored on cleanup.
func withPolicyLoader(t *testing.T, fn func() *Policy) {
t.Helper()
prev := policyLoader
policyLoader = fn
t.Cleanup(func() { policyLoader = prev })
// fakePolicyFetcher implements PolicyFetcher returning a scripted
// policy map. Goroutine-safe so the test can mutate the script while
// the ticker is observing it.
type fakePolicyFetcher struct {
mu sync.Mutex
values map[string]any
}
func (f *fakePolicyFetcher) Fetch() map[string]any {
f.mu.Lock()
defer f.mu.Unlock()
if f.values == nil {
return nil
}
out := make(map[string]any, len(f.values))
for k, v := range f.values {
out[k] = v
}
return out
}
func (f *fakePolicyFetcher) set(values map[string]any) {
f.mu.Lock()
defer f.mu.Unlock()
f.values = values
}
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
var mu sync.Mutex
current := NewPolicy(nil) // initial observation: empty (no enforcement)
withPolicyLoader(t, func() *Policy {
mu.Lock()
defer mu.Unlock()
return current
})
fetcher := &fakePolicyFetcher{} // initial observation: empty (no enforcement)
loader := NewLoader(fetcher)
type change struct{ prev, curr *Policy }
changes := make(chan change, 1)
tk := NewTicker(testReloadInterval)
tk := NewTicker(testReloadInterval, loader)
require.Equal(t, testReloadInterval, tk.interval)
ctx, cancel := context.WithCancel(context.Background())
@@ -49,15 +61,13 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
})
close(done)
}()
// Stop Run and wait for it to exit before returning, so the policyLoader
// restore in t.Cleanup can't race the ticker goroutine still reading it.
// Stop Run and wait for it to exit before returning, so the test
// goroutine doesn't race the still-running ticker.
defer func() { cancel(); <-done }()
// Flip the OS-observed policy from empty to one managed key. The next
// tick must detect the diff and invoke onChange.
mu.Lock()
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
mu.Unlock()
// Flip the OS-observed policy from empty to one managed key. The
// next tick must detect the diff and invoke onChange.
fetcher.set(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
select {
case c := <-changes:
@@ -69,12 +79,11 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
}
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
withPolicyLoader(t, func() *Policy {
return NewPolicy(map[string]any{KeyBlockInbound: true})
})
fetcher := &fakePolicyFetcher{values: map[string]any{KeyBlockInbound: true}}
loader := NewLoader(fetcher)
fired := make(chan struct{}, 1)
tk := NewTicker(testReloadInterval)
tk := NewTicker(testReloadInterval, loader)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
@@ -90,8 +99,8 @@ func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
}()
defer func() { cancel(); <-done }()
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
// diff guard must suppress the callback entirely.
// Over ~2 ticks at the 1s test cadence the policy never changes,
// so the diff guard must suppress the callback entirely.
select {
case <-fired:
t.Fatal("onChange fired despite an unchanged policy")

File diff suppressed because it is too large Load Diff

View File

@@ -85,8 +85,6 @@ service DaemonService {
rpc AddProfile(AddProfileRequest) returns (AddProfileResponse) {}
rpc RenameProfile(RenameProfileRequest) returns (RenameProfileResponse) {}
rpc RemoveProfile(RemoveProfileRequest) returns (RemoveProfileResponse) {}
rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {}
@@ -380,9 +378,6 @@ message RelayState {
string URI = 1;
bool available = 2;
string error = 3;
// transport is the negotiated relay transport (e.g. "ws", "quic"),
// empty for stun/turn probes or when not connected.
string transport = 4;
}
message NSGroupState {
@@ -627,18 +622,11 @@ message GetEventsResponse {
}
message SwitchProfileRequest {
// profileName is treated as a handle: exact ID, unique ID prefix, or
// unique display name. The daemon resolves it server-side.
optional string profileName = 1;
optional string username = 2;
}
message SwitchProfileResponse {
// id is the resolved on-disk ID of the profile that became active.
// Lets CLI clients update their local active-profile state without
// duplicating the resolution logic.
string id = 1;
}
message SwitchProfileResponse {}
message SetConfigRequest {
string username = 1;
@@ -705,42 +693,17 @@ message SetConfigResponse{}
message AddProfileRequest {
string username = 1;
// profileName carries the human-readable display name for the new
// profile. The on-disk filename is a separately-generated ID.
string profileName = 2;
}
message AddProfileResponse {
// id is the generated on-disk ID of the new profile. CLI clients
// display a truncated form, UI clients can ignore it.
string id = 1;
}
message RenameProfileRequest {
string username = 1;
// handle: an exact ID, a unique ID prefix, or a unique display name.
string handle = 2;
// newProfileName is the new human-readable display name for the profile.
string newProfileName = 3;
}
message RenameProfileResponse {
// confirm the old profile name after resolving handle.
string oldProfileName = 1;
}
message AddProfileResponse {}
message RemoveProfileRequest {
string username = 1;
// profileName is treated as a handle: an exact ID, a unique ID
// prefix, or a unique display name. Resolution happens server-side.
string profileName = 2;
}
message RemoveProfileResponse {
// id is the full resolved ID of the removed profile, so callers can
// confirm exactly which profile a name/prefix handle resolved to.
string id = 1;
}
message RemoveProfileResponse {}
message ListProfilesRequest {
string username = 1;
@@ -753,7 +716,6 @@ message ListProfilesResponse {
message Profile {
string name = 1;
bool is_active = 2;
string id = 3;
}
message GetActiveProfileRequest {}
@@ -761,7 +723,6 @@ message GetActiveProfileRequest {}
message GetActiveProfileResponse {
string profileName = 1;
string username = 2;
string id = 3;
}
message LogoutRequest {

View File

@@ -45,7 +45,6 @@ const (
DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile"
DaemonService_SetConfig_FullMethodName = "/daemon.DaemonService/SetConfig"
DaemonService_AddProfile_FullMethodName = "/daemon.DaemonService/AddProfile"
DaemonService_RenameProfile_FullMethodName = "/daemon.DaemonService/RenameProfile"
DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile"
DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles"
DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile"
@@ -113,7 +112,6 @@ type DaemonServiceClient interface {
SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error)
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error)
RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error)
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
@@ -424,16 +422,6 @@ func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequ
return out, nil
}
func (c *daemonServiceClient) RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RenameProfileResponse)
err := c.cc.Invoke(ctx, DaemonService_RenameProfile_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoveProfileResponse)
@@ -625,7 +613,6 @@ type DaemonServiceServer interface {
SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error)
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error)
RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error)
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
@@ -736,9 +723,6 @@ func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigReq
func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AddProfile not implemented")
}
func (UnimplementedDaemonServiceServer) RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RenameProfile not implemented")
}
func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RemoveProfile not implemented")
}
@@ -1253,24 +1237,6 @@ func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RenameProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RenameProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RenameProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_RenameProfile_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RenameProfile(ctx, req.(*RenameProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoveProfileRequest)
if err := dec(in); err != nil {
@@ -1601,10 +1567,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "AddProfile",
Handler: _DaemonService_AddProfile_Handler,
},
{
MethodName: "RenameProfile",
Handler: _DaemonService_RenameProfile_Handler,
},
{
MethodName: "RemoveProfile",
Handler: _DaemonService_RemoveProfile_Handler,

View File

@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")

View File

@@ -20,10 +20,6 @@ import (
// a no-op echo, never as a conflict with the policy.
const preSharedKeyRedactedSentinel = "**********"
// loadMDMPolicy is the indirection used by server handlers to read the
// active MDM policy. Tests override this to inject a fake policy.
var loadMDMPolicy = mdm.LoadPolicy
// conflictCheck is a value-aware comparison between a single field in
// the incoming request and the corresponding MDM-enforced value. It
// runs only when the field was actually set in the request (presence

View File

@@ -78,7 +78,7 @@ type Server struct {
// changed by connectWithRetryRuns goroutine exit — for that
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
// derives from clientGiveUpChan close state. Protected by s.mutex.
clientRunning bool
clientRunning bool
clientRunningChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
@@ -110,6 +110,15 @@ type Server struct {
// stopped by the rootCtx cancellation.
mdmTicker *mdm.Ticker
// mdmLoader is the daemon-owned source of the active MDM policy.
// Constructed once during Server.Start (with a nil PolicyFetcher on
// desktop — the build-tagged Loader.loadPlatform reads the OS
// registry / plist directly) and injected into every consumer:
// mdmTicker for its periodic reload, the SetConfig / Login MDM
// gates for conflict detection, and every Config produced via
// getConfig() so its apply() picks up the same overlay.
mdmLoader *mdm.Loader
updateManager *updater.Manager
jwtCache *jwtCache
@@ -173,8 +182,14 @@ func (s *Server) Start() error {
// Runs re-resolves Config (re-running profilemanager.Config.apply which
// applies the freshly-read MDM policy as the last layer) and brings
// the engine back with the new values.
if s.mdmLoader == nil {
// Desktop builds pass a nil PolicyFetcher: the Loader's
// build-tagged loadPlatform reads the OS source directly
// (registry on Windows, plist on macOS, no-op elsewhere).
s.mdmLoader = mdm.NewLoader(nil)
}
if s.mdmTicker == nil {
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval, s.mdmLoader)
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
}
@@ -370,12 +385,12 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
// by the active MDM policy. The error carries an MDMManagedFields-
// Violation detail listing the offending key names. Non-conflicting
// fields in the same request are not applied either.
policy := loadMDMPolicy()
policy := s.mdmLoader.Load()
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
return nil, err
}
config, err := s.setConfigInputFromRequest(msg)
config, err := setConfigInputFromRequest(msg)
if err != nil {
return nil, err
}
@@ -398,17 +413,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
// field is its own optional case. Returns the resolved ConfigInput
// and a non-nil error only when the active profile file path cannot
// be determined.
func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
var config profilemanager.ConfigInput
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
return config, err
profState := profilemanager.ActiveProfileState{
Name: msg.ProfileName,
Username: msg.Username,
}
profPath := resolved.Path
if profPath == "" {
profPath = profilemanager.DefaultConfigPath
profPath, err := profState.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return config, fmt.Errorf("failed to get active profile file path: %w", err)
}
config.ConfigPath = profPath
@@ -496,7 +511,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
if s.checkUpdateSettingsDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
}
policy := loadMDMPolicy()
policy := s.mdmLoader.Load()
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
return nil, err
}
@@ -535,9 +550,30 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
if msg.ProfileName != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, err
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
}
var username string
if *msg.ProfileName != "default" {
username = *msg.Username
}
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: *msg.ProfileName,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
}
}
@@ -547,7 +583,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
s.mutex.Lock()
@@ -785,10 +821,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
if msg != nil && msg.ProfileName != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err)
return nil, err
return nil, fmt.Errorf("failed to switch profile: %w", err)
}
}
@@ -799,7 +835,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
config, _, err := s.getConfig(activeProf)
if err != nil {
@@ -843,60 +879,34 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
}
}
// resolveProfileHandle resolves a wire-level profile handle (display
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
// status errors so handlers can return them directly.
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
p, err := s.profileManager.ResolveProfile(handle, username)
if err == nil {
return p, nil
}
var amb *profilemanager.ErrAmbiguousHandle
if errors.As(err, &amb) {
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
}
if errors.Is(err, profilemanager.ErrProfileNotFound) {
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
}
return nil, fmt.Errorf("resolve profile: %w", err)
}
// switchProfileIfNeeded resolves the user-supplied handle, updates the
// active profile state if it differs from the current one, and returns
// the resolved profile so callers can include its ID in RPC responses.
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", handle)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
if profileName != "default" && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", profileName)
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
}
var username string
if handle != profilemanager.DefaultProfileName {
if profileName != "default" {
username = *userName
}
resolved, err := s.resolveProfileHandle(handle, username)
if err != nil {
return nil, err
}
if resolved.ID != activeProf.ID || username != activeProf.Username {
if profileName != activeProf.Name || username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
log.Infof("switching to profile %s for user %s", profileName, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: resolved.ID,
Name: profileName,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
return fmt.Errorf("failed to set active profile state: %w", err)
}
}
return resolved, nil
return nil
}
// SwitchProfile switches the active profile in the daemon.
@@ -911,9 +921,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
}
if msg != nil && msg.ProfileName != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, err
return nil, fmt.Errorf("failed to switch profile: %w", err)
}
}
activeProf, err = s.profileManager.GetActiveProfileState()
@@ -929,7 +939,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
s.config = config
return &proto.SwitchProfileResponse{Id: activeProf.ID.String()}, nil
return &proto.SwitchProfileResponse{}, nil
}
// Down engine work in the daemon.
@@ -1019,27 +1029,22 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
}
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
return nil, err
}
if msg.Username == nil || *msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
}
username := *msg.Username
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
if err != nil {
return nil, err
}
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
}
activeProf, _ := s.profileManager.GetActiveProfileState()
if activeProf != nil && activeProf.ID == resolved.ID {
if activeProf != nil && activeProf.Name == *msg.ProfileName {
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
log.Errorf("failed to cleanup connection: %v", err)
}
@@ -1098,33 +1103,39 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
return nil, false, fmt.Errorf("failed to get config: %w", err)
}
// Apply the daemon-owned MDM policy on top of the just-resolved
// Config. profilemanager's apply() initialises the policy to
// empty — the Loader lives outside Config, so this overlay step
// is driven externally here.
config.ApplyMDMPolicy(s.mdmLoader.Load())
return config, configExisted, nil
}
func (s *Server) canRemoveProfile(id profilemanager.ID) error {
if id == profilemanager.DefaultProfileName {
func (s *Server) canRemoveProfile(profileName string) error {
if profileName == profilemanager.DefaultProfileName {
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.ID == id {
return fmt.Errorf("remove active profile: %s", id)
if err == nil && activeProf.Name == profileName {
return fmt.Errorf("remove active profile: %s", profileName)
}
return nil
}
func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfile bool) error {
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
if s.checkProfilesDisabled() {
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if id == "" {
if profileName == "" {
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
if !allowActiveProfile {
if err := s.canRemoveProfile(id); err != nil {
if err := s.canRemoveProfile(profileName); err != nil {
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
}
}
@@ -1132,21 +1143,29 @@ func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfi
return nil
}
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
return s.sendLogoutRequest(ctx)
}
cfgPath := profile.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
profileState := &profilemanager.ActiveProfileState{
Name: profileName,
Username: username,
}
profilePath, err := profileState.FilePath()
if err != nil {
return fmt.Errorf("get profile path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
config, err := profilemanager.GetConfig(profilePath)
if err != nil {
return fmt.Errorf("profile '%s' not found", profile.ID)
return fmt.Errorf("profile '%s' not found", profileName)
}
// Honour any MDM-enforced ManagementURL when issuing the logout
// RPC: the user-stored value may have been overridden by policy.
config.ApplyMDMPolicy(s.mdmLoader.Load())
return s.sendLogoutRequestWithConfig(ctx, config)
}
@@ -1563,14 +1582,15 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
return nil, ctx.Err()
}
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
if err != nil {
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
return nil, err
prof := profilemanager.ActiveProfileState{
Name: req.ProfileName,
Username: req.Username,
}
cfgPath := resolved.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
cfgPath, err := prof.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
}
cfg, err := profilemanager.GetConfig(cfgPath)
@@ -1578,6 +1598,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
}
// Overlay the active MDM policy so the response's MDMManagedFields
// list reflects what the GUI / CLI must render as read-only.
// profilemanager.GetConfig itself returns a Config without the
// overlay (Loader lives outside profilemanager).
cfg.ApplyMDMPolicy(s.mdmLoader.Load())
managementURL := cfg.ManagementURL
adminURL := cfg.AdminURL
@@ -1675,39 +1700,12 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
if err != nil {
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{Id: created.ID.String()}, nil
}
func (s *Server) RenameProfile(ctx context.Context, msg *proto.RenameProfileRequest) (*proto.RenameProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.Handle == "" || msg.Username == "" || msg.NewProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name, username and new profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.Handle, msg.Username)
if err != nil {
return nil, err
}
err = s.profileManager.RenameProfile(resolved.ID, msg.Username, msg.NewProfileName)
if err != nil {
log.Errorf("failed to rename profile: %v", err)
return nil, fmt.Errorf("failed to rename profile: %w", err)
}
return &proto.RenameProfileResponse{OldProfileName: resolved.Name}, nil
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
@@ -1715,29 +1713,20 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{Id: resolved.ID.String()}, nil
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
@@ -1760,7 +1749,6 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Id: profile.ID.String(),
Name: profile.Name,
IsActive: profile.IsActive,
}
@@ -1769,9 +1757,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
return response, nil
}
// GetActiveProfile returns the active profile in the daemon. The ProfileName
// field carries the display name for backwards compatibility with UI clients,
// new callers should prefer Id.
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -1782,23 +1768,9 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
// Fallback to legacy name == ID
displayName := activeProfile.ID.String()
if activeProfile.ID != profilemanager.DefaultProfileName {
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
for _, p := range profiles {
if p.ID == activeProfile.ID {
displayName = p.Name
break
}
}
}
}
return &proto.GetActiveProfileResponse{
ProfileName: displayName,
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
Id: activeProfile.ID.String(),
}, nil
}

View File

@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Name: "test-profile",
Username: currUser.Username,
})
if err != nil {
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(profName),
Name: profName,
Username: currUser.Username,
})
if err != nil {
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "default",
Name: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -16,14 +16,40 @@ import (
"github.com/netbirdio/netbird/client/proto"
)
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
// so SetConfig observes the supplied Policy. Restores the original loader
// at test cleanup.
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
// fakeMDMFetcher implements mdm.PolicyFetcher returning a pre-set
// policy map. Tests build one per Server instance to inject a
// scripted MDM overlay via a Loader rather than via package-level state.
type fakeMDMFetcher struct{ values map[string]any }
func (f *fakeMDMFetcher) Fetch() map[string]any { return f.values }
// withMDMPolicy installs an mdm.Loader on the given Server whose
// loadPlatform returns the supplied Policy's underlying values. Use
// after setupServerWithProfile to inject the scripted policy the
// SetConfig / Login MDM gates will observe.
func withMDMPolicy(t *testing.T, s *Server, policy *mdm.Policy) {
t.Helper()
prev := loadMDMPolicy
loadMDMPolicy = func() *mdm.Policy { return policy }
t.Cleanup(func() { loadMDMPolicy = prev })
values := map[string]any{}
if policy != nil {
for _, k := range policy.ManagedKeys() {
if v, ok := policy.GetString(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetBool(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetInt(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetStringSlice(k); ok {
values[k] = v
}
}
}
s.mdmLoader = mdm.NewLoader(&fakeMDMFetcher{values: values})
}
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
@@ -62,7 +88,7 @@ func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profN
pm := profilemanager.ServiceManager{}
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(profName),
Name: profName,
Username: currUser.Username,
}))
@@ -89,12 +115,11 @@ func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation
}
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
Username: username,
@@ -106,14 +131,13 @@ func TestSetConfig_MDMReject_SingleField(t *testing.T) {
}
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
blockInbound := false
rosenpassEnabled := false
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
@@ -137,12 +161,11 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
// enforced field AND a non-enforced field (RosenpassEnabled).
// The whole request must be rejected — non-conflicting fields are not
// applied either.
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
rosenpassEnabled := true
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
@@ -164,12 +187,11 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
// Request must succeed.
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
rosenpassEnabled := true
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
@@ -183,9 +205,8 @@ func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
// No MDM policy active: any field can be written.
withMDMPolicy(t, mdm.NewPolicy(nil))
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(nil))
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,

View File

@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(profName),
Name: profName,
Username: currUser.Username,
})
require.NoError(t, err)
@@ -96,7 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
ID: profilemanager.ID(profName),
Name: profName,
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()

View File

@@ -98,7 +98,6 @@ type RelayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
}
type RelayStateOutput struct {
@@ -220,8 +219,7 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
RelayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relayErrorString(relay.GetError()),
Transport: relay.GetTransport(),
Error: relay.GetError(),
},
)
@@ -237,12 +235,6 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
}
}
// relayErrorString flattens a newline-joined aggregated relay error onto a
// single line for status output.
func relayErrorString(s string) string {
return strings.ReplaceAll(s, "\n", "; ")
}
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
@@ -449,8 +441,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} else if relay.Transport != "" {
available = fmt.Sprintf("%s via %s", available, relay.Transport)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)

View File

@@ -647,13 +647,3 @@ func TestTimeAgo(t *testing.T) {
})
}
}
func TestMapRelaysTransport(t *testing.T) {
out := mapRelays([]*proto.RelayState{
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
})
require.Len(t, out.Details, 2)
assert.Equal(t, "quic", out.Details[0].Transport)
assert.Equal(t, "ws", out.Details[1].Transport)
}

View File

@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
}
req := &proto.SetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
}
@@ -818,15 +818,13 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err)
}
handle := activeProf.ID.String()
loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &handle,
ProfileName: &activeProf.Name,
Username: &currUser.Username,
}
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -1369,7 +1367,7 @@ func (s *serviceClient) getSrvConfig() {
}
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
})
if err != nil {
@@ -1615,7 +1613,7 @@ func (s *serviceClient) loadSettings() {
}
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
})
if err != nil {
@@ -1815,7 +1813,7 @@ func (s *serviceClient) updateConfig() error {
}
req := proto.SetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,

View File

@@ -66,7 +66,7 @@ func (s *serviceClient) showProfilesUI() {
} else {
indicator.SetText("")
}
nameLabel.SetText(formatProfileLabel(profile, profiles))
nameLabel.SetText(profile.Name)
// Configure Select/Active button
selectBtn.SetText(func() string {
@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
// switch
err = s.switchProfile(profile.ID)
err = s.switchProfile(profile.Name)
if err != nil {
log.Errorf("failed to switch profile: %v", err)
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
logoutBtn.Show()
logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile, refresh)
s.handleProfileLogout(profile.Name, refresh)
}
// Remove profile
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
err = s.removeProfile(profile.ID)
err = s.removeProfile(profile.Name)
if err != nil {
log.Errorf("failed to remove profile: %v", err)
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
return nil
}
func (s *serviceClient) switchProfile(handle string) error {
func (s *serviceClient) switchProfile(profileName string) error {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return fmt.Errorf(getClientFMT, err)
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(handle string) error {
return fmt.Errorf("get current user: %w", err)
}
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
Username: &currUser.Username,
})
if err != nil {
}); err != nil {
return fmt.Errorf("switch profile failed: %w", err)
}
if err := s.profileManager.SwitchProfile(profilemanager.ID(resp.Id)); err != nil {
err = s.profileManager.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %w", err)
}
@@ -299,27 +299,10 @@ func (s *serviceClient) removeProfile(profileName string) error {
}
type Profile struct {
ID string
Name string
IsActive bool
}
// formatProfileLabel returns the display label for a profile. Profiles can
// share the same Name, so when more than one profile in profiles carries this
// Name, a short form of the ID is appended to disambiguate the entries.
func formatProfileLabel(profile Profile, profiles []Profile) string {
count := 0
for _, p := range profiles {
if p.Name == profile.Name {
count++
}
}
if count <= 1 {
return profile.Name
}
return fmt.Sprintf("%s (%s)", profile.Name, profilemanager.ID(profile.ID).ShortID())
}
func (s *serviceClient) getProfiles() ([]Profile, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -341,7 +324,6 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -350,10 +332,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
return profiles, nil
}
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
dialog.ShowConfirm(
"Deregister",
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
func(confirm bool) {
if !confirm {
return
@@ -374,10 +356,8 @@ func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback fun
}
username := currUser.Username
// ProfileName is treated as a handle; send the ID so the
// daemon resolves to exactly this profile.
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profile.ID,
ProfileName: &profileName,
Username: &username,
})
if err != nil {
@@ -388,7 +368,7 @@ func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback fun
dialog.ShowInformation(
"Deregistered",
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
s.wProfiles,
)
@@ -481,7 +461,6 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -522,7 +501,7 @@ func (p *profileMenu) refresh() {
}
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
activeProfState, err := p.profileManager.GetProfileState(profilemanager.ID(activeProf.Id))
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
if err != nil {
log.Warnf("failed to get active profile state: %v", err)
p.emailMenuItem.Hide()
@@ -533,7 +512,7 @@ func (p *profileMenu) refresh() {
}
for _, profile := range profiles {
item := p.profileMenuItem.AddSubMenuItem(formatProfileLabel(profile, profiles), "")
item := p.profileMenuItem.AddSubMenuItem(profile.Name, "")
if profile.IsActive {
item.Check()
}
@@ -562,8 +541,8 @@ func (p *profileMenu) refresh() {
return
}
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.ID,
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.Name,
Username: &currUser.Username,
})
if err != nil {
@@ -573,7 +552,7 @@ func (p *profileMenu) refresh() {
return
}
err = p.profileManager.SwitchProfile(profilemanager.ID(switchResp.Id))
err = p.profileManager.SwitchProfile(profile.Name)
if err != nil {
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
return
@@ -748,10 +727,7 @@ func (p *profileMenu) updateMenu() {
}
sort.Slice(profiles, func(i, j int) bool {
if profiles[i].Name != profiles[j].Name {
return profiles[i].Name < profiles[j].Name
}
return profiles[i].ID < profiles[j].ID
return profiles[i].Name < profiles[j].Name
})
p.mu.Lock()

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
"github.com/netbirdio/netbird/util"
)
@@ -31,7 +30,6 @@ const (
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
dialWebSocketTimeout = 30 * time.Second
icmpEchoRequest = 8
icmpCodeEcho = 0
@@ -679,7 +677,6 @@ func createClientObject(client *netbird.Client) js.Value {
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["dialWebSocket"] = createDialWebSocketMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
@@ -694,74 +691,6 @@ func createClientObject(client *netbird.Client) js.Value {
return js.ValueOf(obj)
}
func createDialWebSocketMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
url, protocols, timeout, errVal := parseDialWebSocketArgs(args)
if !errVal.IsUndefined() {
return errVal
}
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := nbwebsocket.Dial(ctx, client, url, protocols)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
return
}
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
})
})
}
func parseDialWebSocketArgs(args []js.Value) (url string, protocols []string, timeout time.Duration, errVal js.Value) {
if len(args) < 1 || args[0].Type() != js.TypeString {
return "", nil, 0, js.ValueOf("error: dialWebSocket requires a URL string argument")
}
url = args[0].String()
if len(args) >= 2 && !args[1].IsNull() && !args[1].IsUndefined() {
arr, err := jsStringArray(args[1])
if err != nil {
return "", nil, 0, js.ValueOf(fmt.Sprintf("error: protocols: %v", err))
}
protocols = arr
}
timeout = dialWebSocketTimeout
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
if args[2].Type() != js.TypeNumber {
return "", nil, 0, js.ValueOf("error: timeoutMs must be a number")
}
timeoutMs := args[2].Int()
if timeoutMs <= 0 {
return "", nil, 0, js.ValueOf("error: timeout must be positive")
}
timeout = time.Duration(timeoutMs) * time.Millisecond
}
return url, protocols, timeout, js.Undefined()
}
// jsStringArray converts a JS array of strings to a Go []string.
func jsStringArray(v js.Value) ([]string, error) {
if !v.InstanceOf(js.Global().Get("Array")) {
return nil, fmt.Errorf("expected array")
}
n := v.Length()
out := make([]string, n)
for i := 0; i < n; i++ {
el := v.Index(i)
if el.Type() != js.TypeString {
return nil, fmt.Errorf("element %d is not a string", i)
}
out[i] = el.String()
}
return out, nil
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {

View File

@@ -1,304 +0,0 @@
//go:build js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
netbird "github.com/netbirdio/netbird/client/embed"
log "github.com/sirupsen/logrus"
)
type closeError struct {
code uint16
reason string
}
func (e *closeError) Error() string {
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
}
// bufferedConn fronts a net.Conn with a reader that serves any bytes buffered
// during the WebSocket handshake before falling through to the raw conn.
type bufferedConn struct {
net.Conn
r io.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) }
// Conn wraps a WebSocket connection over a NetBird TCP connection.
type Conn struct {
conn net.Conn
mu sync.Mutex
closed chan struct{}
closeOnce sync.Once
closeErr error
}
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
// Optional protocols are sent via the Sec-WebSocket-Protocol header.
func Dial(ctx context.Context, client *netbird.Client, rawURL string, protocols []string) (*Conn, error) {
d := ws.Dialer{
NetDial: client.Dial,
Protocols: protocols,
}
conn, br, _, err := d.Dial(ctx, rawURL)
if err != nil {
return nil, fmt.Errorf("websocket dial: %w", err)
}
// br is non-nil when the server pushed frames alongside the handshake
// response; those bytes live in the bufio.Reader and must be drained
// before reading from conn, otherwise we'd skip the first frames.
if br != nil {
if br.Buffered() > 0 {
conn = &bufferedConn{Conn: conn, r: io.MultiReader(br, conn)}
} else {
ws.PutReader(br)
}
}
return &Conn{
conn: conn,
closed: make(chan struct{}),
}, nil
}
// ReadMessage reads the next WebSocket message, handling control frames automatically.
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
for {
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
if err != nil {
return 0, nil, err
}
for _, msg := range msgs {
if msg.OpCode.IsControl() {
if err := c.handleControl(msg); err != nil {
return 0, nil, err
}
continue
}
return msg.OpCode, msg.Payload, nil
}
}
}
func (c *Conn) handleControl(msg wsutil.Message) error {
switch msg.OpCode {
case ws.OpPing:
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
case ws.OpClose:
code, reason := parseClosePayload(msg.Payload)
return &closeError{code: code, reason: reason}
default:
return nil
}
}
// WriteText sends a text WebSocket message.
func (c *Conn) WriteText(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
}
// WriteBinary sends a binary WebSocket message.
func (c *Conn) WriteBinary(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
}
// Close sends a close frame with StatusNormalClosure and closes the underlying connection.
func (c *Conn) Close() error {
return c.closeWith(ws.StatusNormalClosure, "")
}
// closeWith sends a close frame with the given code/reason and closes the underlying connection.
// Used to echo the server's code when responding to a server-initiated close per RFC 6455 §5.5.1.
func (c *Conn) closeWith(code ws.StatusCode, reason string) error {
var first bool
c.closeOnce.Do(func() {
first = true
close(c.closed)
c.mu.Lock()
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose, ws.NewCloseFrameBody(code, reason))
c.mu.Unlock()
c.closeErr = c.conn.Close()
})
if !first {
return net.ErrClosed
}
return c.closeErr
}
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
// It exposes: send(string|Uint8Array), close(), and callback properties
// onmessage, onclose, onerror.
//
// Callback properties may be set from the JS thread while the read loop
// goroutine reads them. In WASM this is safe because Go and JS share a
// single thread, but the design would need synchronization on
// multi-threaded runtimes.
func NewJSInterface(conn *Conn) js.Value {
obj := js.Global().Get("Object").Call("create", js.Null())
sendFunc := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
log.Errorf("websocket send requires a data argument")
return js.ValueOf(false)
}
data := args[0]
switch data.Type() {
case js.TypeString:
if err := conn.WriteText([]byte(data.String())); err != nil {
log.Errorf("failed to send websocket text: %v", err)
return js.ValueOf(false)
}
default:
buf, err := jsToBytes(data)
if err != nil {
log.Errorf("failed to convert js value to bytes: %v", err)
return js.ValueOf(false)
}
if err := conn.WriteBinary(buf); err != nil {
log.Errorf("failed to send websocket binary: %v", err)
return js.ValueOf(false)
}
}
return js.ValueOf(true)
})
obj.Set("send", sendFunc)
closeFunc := js.FuncOf(func(_ js.Value, _ []js.Value) any {
if err := conn.Close(); err != nil {
log.Debugf("failed to close websocket: %v", err)
}
return js.Undefined()
})
obj.Set("close", closeFunc)
go func() {
defer func() {
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("close websocket on readLoop exit: %v", err)
}
}()
readLoop(conn, obj)
// Undefining before Release turns post-close JS calls into TypeError
// instead of a silent "call to released function".
obj.Set("send", js.Undefined())
obj.Set("close", js.Undefined())
sendFunc.Release()
closeFunc.Release()
}()
return obj
}
func jsToBytes(data js.Value) ([]byte, error) {
var uint8Array js.Value
switch {
case data.InstanceOf(js.Global().Get("Uint8Array")):
uint8Array = data
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
uint8Array = js.Global().Get("Uint8Array").New(data)
default:
return nil, fmt.Errorf("send: unsupported data type, use string, Uint8Array, or ArrayBuffer")
}
buf := make([]byte, uint8Array.Get("length").Int())
js.CopyBytesToGo(buf, uint8Array)
return buf, nil
}
func readLoop(conn *Conn, obj js.Value) {
var ce *closeError
defer func() { invokeOnClose(obj, ce) }()
for {
select {
case <-conn.closed:
return
default:
}
op, payload, err := conn.ReadMessage()
if err != nil {
ce = handleReadError(conn, obj, err)
return
}
dispatchMessage(obj, op, payload)
}
}
func handleReadError(conn *Conn, obj js.Value, err error) *closeError {
var ce *closeError
if errors.As(err, &ce) {
if cerr := conn.closeWith(ws.StatusCode(ce.code), ce.reason); cerr != nil {
log.Debugf("failed to close websocket after server close frame: %v", cerr)
}
return ce
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return nil
}
if onerror := obj.Get("onerror"); onerror.Truthy() {
onerror.Invoke(js.ValueOf(err.Error()))
}
return nil
}
func invokeOnClose(obj js.Value, ce *closeError) {
onclose := obj.Get("onclose")
if !onclose.Truthy() {
return
}
if ce != nil {
onclose.Invoke(js.ValueOf(int(ce.code)), js.ValueOf(ce.reason))
return
}
onclose.Invoke()
}
func dispatchMessage(obj js.Value, op ws.OpCode, payload []byte) {
onmessage := obj.Get("onmessage")
if !onmessage.Truthy() {
return
}
switch op {
case ws.OpText:
onmessage.Invoke(js.ValueOf(string(payload)))
case ws.OpBinary:
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
js.CopyBytesToJS(uint8Array, payload)
onmessage.Invoke(uint8Array)
}
}
func parseClosePayload(payload []byte) (uint16, string) {
if len(payload) < 2 {
return 1005, "" // RFC 6455: No Status Rcvd
}
code := binary.BigEndian.Uint16(payload[:2])
return code, string(payload[2:])
}

View File

@@ -2,5 +2,4 @@ FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
ARG TARGETPLATFORM
COPY ${TARGETPLATFORM}/netbird-server /go/bin/netbird-server
COPY netbird-server /go/bin/netbird-server

3
go.mod
View File

@@ -56,7 +56,6 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.4
github.com/gobwas/ws v1.4.0
github.com/goccy/go-yaml v1.18.0
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
@@ -216,8 +215,6 @@ require (
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.3 // indirect

7
go.sum
View File

@@ -249,12 +249,6 @@ github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4C
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
@@ -851,7 +845,6 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -2,5 +2,4 @@ FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]
ARG TARGETPLATFORM
COPY ${TARGETPLATFORM}/netbird-mgmt /go/bin/netbird-mgmt
COPY netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -0,0 +1,5 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
CMD ["--log-file", "console"]
COPY netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -45,7 +45,7 @@ type Controller struct {
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
affectedPeerUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
@@ -64,13 +64,6 @@ type bufferUpdate struct {
update atomic.Bool
}
type bufferAffectedUpdate struct {
sendMu sync.Mutex
dataMu sync.Mutex
next *time.Timer
peerIDs map[string]struct{}
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
@@ -208,7 +201,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
@@ -233,6 +226,44 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
@@ -242,143 +273,6 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
return c.sendUpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
if len(peerIDs) == 0 {
return nil
}
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
}
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
if !c.hasConnectedPeers(peerIDs) {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
return nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
if len(peersToUpdate) == 0 {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
return nil
}
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range peersToUpdate {
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
return true
}
}
return false
}
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
var result []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
result = append(result, peer)
}
}
return result
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
@@ -487,164 +381,66 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
if len(peerIDs) == 0 {
return nil
}
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
})
b := bufUpd.(*bufferAffectedUpdate)
b.addPeerIDs(peerIDs)
if !b.sendMu.TryLock() {
// Another goroutine is already sending; it will pick up our IDs on its next drain.
return nil
}
b.stopTimer()
// The send and the debounced timer outlive the calling request, so detach from
// its context to avoid sending with a cancelled context once the handler returns.
bgCtx := context.WithoutCancel(ctx)
collected := b.drainPeerIDs()
go func() {
defer b.sendMu.Unlock()
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, collected)
// Check if more peer IDs accumulated while we were sending.
if !b.hasPending() {
return
}
// Schedule a debounced flush for the newly accumulated IDs.
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
ids := b.drainPeerIDs()
if len(ids) > 0 {
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, ids)
}
})
}()
return nil
}
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
b.dataMu.Lock()
for _, id := range ids {
b.peerIDs[id] = struct{}{}
}
b.dataMu.Unlock()
}
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if len(b.peerIDs) == 0 {
return nil
}
ids := make([]string, 0, len(b.peerIDs))
for id := range b.peerIDs {
ids = append(ids, id)
}
b.peerIDs = make(map[string]struct{})
return ids
}
func (b *bufferAffectedUpdate) hasPending() bool {
b.dataMu.Lock()
defer b.dataMu.Unlock()
return len(b.peerIDs) > 0
}
func (b *bufferAffectedUpdate) stopTimer() {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next != nil {
b.next.Stop()
}
}
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next == nil {
b.next = time.AfterFunc(d, f)
return
}
b.next.Reset(d)
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, 0, err
return nil, nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return emptyMap, nil, 0, nil
return peer, emptyMap, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, 0, err
return nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, 0, err
return nil, nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peerID)
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, 0, err
return nil, nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, 0, err
return nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, 0, err
return nil, nil, nil, 0, err
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peerID]
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return networkMap, postureChecks, dnsFwdPort, nil
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
// GetDNSDomain returns the configured dnsDomain
@@ -782,24 +578,21 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
return nil
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
return nil
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
@@ -832,11 +625,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -19,19 +19,17 @@ const (
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)

View File

@@ -57,20 +57,6 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
}
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
@@ -127,20 +113,21 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, peerID)
ret0, _ := ret[0].(*types.NetworkMap)
ret1, _ := ret[1].([]*posture.Checks)
ret2, _ := ret[2].(int64)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(int64)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peerID any) *gomock.Call {
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
}
// OnPeerConnected mocks base method.
@@ -171,45 +158,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
}
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
}
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
}
// StartWarmup mocks base method.
@@ -263,17 +250,3 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}

View File

@@ -242,7 +242,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
},
}
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
if err != nil {
return fmt.Errorf("failed to create proxy peer: %w", err)
}

View File

@@ -918,10 +918,6 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
}
for _, svc := range services {
if err = transaction.DeleteServiceTargets(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
@@ -1274,10 +1270,6 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
@@ -1327,10 +1319,6 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
return nil
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv
}
@@ -458,9 +458,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -563,9 +560,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -610,9 +604,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -723,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1147,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1201,67 +1192,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
func TestDeleteExpiredPeerService_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before reaping")
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
err = mgr.deleteExpiredPeerService(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when an expired peer-exposed service is reaped")
}
func TestDeleteServiceFromPeer_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before stopping")
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "stopped peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when a peer stops its exposed service")
}
func TestValidateProtocolChange(t *testing.T) {
tests := []struct {
name string

View File

@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -33,8 +33,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
@@ -84,9 +82,6 @@ type ProxyServiceServer struct {
// Manager for users
usersManager users.Manager
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
idpManager idp.Manager
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
@@ -162,7 +157,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -171,7 +166,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
idpManager: idpManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
@@ -1708,7 +1702,22 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
@@ -1745,45 +1754,6 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}, nil
}
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
// user or peer ID, and peer name or user email.
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
// Resolve the principal: when the peer is linked to a user, the human is the
// principal so multiple peers owned by the same user share a single
// identity. Unlinked peers (machine agents) are their own principal keyed on
// peer.ID. displayIdentity is what upstream gateways tag spend with —
// user.Email when linked, peer.Name when not.
// If the peer isn't associated with a user, return the peer info directly.
if peer.UserID == "" {
return peer.ID, peer.Name
}
// Otherwise, if the peer is linked to a user, the user is the principal and
// if an IdP is available, we gather details on the user from it.
principalID := peer.UserID
displayIdentity := peer.Name
// Stored column first (cheap, but often empty for OIDC-provisioned users).
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
// IdP enrichment wins when available — the stored email column is a
// best-effort cache and is frequently empty for OIDC users. Enrichment
// failures must never fail the RPC; we simply keep the stored/peer identity.
if s.idpManager != nil {
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
displayIdentity = ud.Email
} else if uerr != nil {
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
}
}
return principalID, displayIdentity
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security

View File

@@ -3,19 +3,14 @@ package grpc
import (
"context"
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockReverseProxyManager struct {
@@ -142,52 +137,6 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
return user, nil, nil
}
// mockTunnelPeersManager implements only the two peers.Manager methods that
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
// panics if any unexpected method is invoked).
type mockTunnelPeersManager struct {
peers.Manager
peer *peer.Peer
peerErr error
groups []*types.Group
groupsErr error
}
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
return m.peer, m.peerErr
}
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
return m.peer, m.groups, m.groupsErr
}
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
// an IdP that knows nothing about the user.
type mockTunnelIdpManager struct {
idp.Manager
email string
hasData bool
err error
gotCalls int
gotMeta []idp.AppMetadata
}
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
m.gotCalls++
m.gotMeta = append(m.gotMeta, meta)
if m.err != nil {
return nil, m.err
}
if !m.hasData {
// This might not be a thing any of the actual IDP implementations do,
// i.e. return a nil value with no error, but it seems valuable to test
// that behavior here.
return nil, nil //nolint:nilnil
}
return &idp.UserData{ID: userID, Email: m.email}, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -405,163 +354,6 @@ func TestValidateUserGroupAccess(t *testing.T) {
}
}
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
// (IdP email -> stored User.Email -> peer.Name).
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
const (
domain = "app.example.com"
accountID = "account1"
peerID = "peer1"
peerName = "peer-display-name"
userID = "user1"
)
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
tests := []struct {
name string
peerUserID string
storedUsers map[string]*types.User
storedErr error
noIdP bool
idpEmail string
idpHasData bool
idpErr error
expectEmail string
expectUserID string
expectIdPHit bool
}{
{
name: "idp email wins over stored email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp returns empty email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "",
idpHasData: true,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp has no data",
peerUserID: userID,
storedUsers: storedUser,
idpHasData: false,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp errors",
peerUserID: userID,
storedUsers: storedUser,
idpErr: errors.New("idp unreachable"),
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when no idp manager",
peerUserID: userID,
storedUsers: storedUser,
noIdP: true,
expectEmail: "stored@example.com",
expectUserID: userID,
},
{
name: "idp email when stored email is empty",
peerUserID: userID,
storedUsers: storedUserNoEmail,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "idp email when stored user missing keeps peer.UserID as principal",
peerUserID: userID,
storedUsers: map[string]*types.User{},
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "unlinked peer uses peer name and never consults idp",
peerUserID: "",
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: peerName,
expectUserID: peerID,
expectIdPHit: false,
},
{
name: "linked peer with empty stored email and no idp falls back to peer name",
peerUserID: userID,
storedUsers: storedUserNoEmail,
noIdP: true,
expectEmail: peerName,
expectUserID: userID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := &service.Service{Domain: domain, AccountID: accountID}
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
},
peersManager: &mockTunnelPeersManager{
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
},
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
}
var idpMock *mockTunnelIdpManager
if !tt.noIdP {
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
server.idpManager = idpMock
}
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
Domain: domain,
TunnelIp: "100.64.0.1",
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.True(t, resp.GetValid(), "expected access granted")
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
assert.Equal(t, tt.expectUserID, resp.GetUserId())
if idpMock != nil {
if tt.expectIdPHit {
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
require.Len(t, idpMock.gotMeta, 1)
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
} else {
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
}
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string

View File

@@ -778,7 +778,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, network, postureChecks, enableSSH, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: peerMeta,
@@ -792,7 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err)
}
loginResp, err := s.prepareLoginResponse(ctx, peer, network, postureChecks, enableSSH)
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
@@ -895,7 +895,7 @@ func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMess
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, network *types.Network, postureChecks []*posture.Checks, enableSSH bool) (*proto.LoginResponse, error) {
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
@@ -914,7 +914,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
Checks: toProtocolChecks(ctx, postureChecks),
}

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)

View File

@@ -1894,7 +1894,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -2577,9 +2577,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -2670,9 +2668,7 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
}
if updateNetworkMap {
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
return fmt.Errorf("notify network map controller: %w", err)
}
}

View File

@@ -13,7 +13,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -62,7 +61,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -70,7 +69,7 @@ type Manager interface {
UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -109,8 +108,8 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
@@ -129,7 +128,6 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error

View File

@@ -15,7 +15,6 @@ import (
dns "github.com/netbirdio/netbird/dns"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
activity "github.com/netbirdio/netbird/management/server/activity"
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
idp "github.com/netbirdio/netbird/management/server/idp"
peer "github.com/netbirdio/netbird/management/server/peer"
posture "github.com/netbirdio/netbird/management/server/posture"
@@ -80,15 +79,14 @@ func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *go
}
// AddPeer mocks base method.
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.Network)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// AddPeer indicates an expected call of AddPeer.
@@ -1290,15 +1288,14 @@ func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock
}
// LoginPeer mocks base method.
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoginPeer", ctx, login)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.Network)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// LoginPeer indicates an expected call of LoginPeer.
@@ -1323,17 +1320,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
}
// MarkPeerDisconnected mocks base method.
@@ -1640,18 +1637,6 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// ExpandAndUpdateAffected mocks base method.
func (m *MockManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ExpandAndUpdateAffected", ctx, accountID, snap, change)
}
// ExpandAndUpdateAffected indicates an expected call of ExpandAndUpdateAffected.
func (mr *MockManagerMockRecorder) ExpandAndUpdateAffected(ctx, accountID, snap, change interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandAndUpdateAffected", reflect.TypeOf((*MockManager)(nil).ExpandAndUpdateAffected), ctx, accountID, snap, change)
}
// UpdateAccountSettings mocks base method.
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
m.ctrl.T.Helper()

View File

@@ -84,7 +84,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
setupKey = key.Key
}
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -1092,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1156,7 +1156,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1504,7 +1504,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
}, false)
@@ -1826,7 +1826,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1882,7 +1882,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1927,7 +1927,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
@@ -1935,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1956,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1980,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1990,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2017,7 +2017,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
}, false)
@@ -2052,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
}()
}
@@ -2080,7 +2080,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -2093,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
@@ -3276,7 +3276,7 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
Status: &nbpeer.PeerStatus{
@@ -3305,19 +3305,6 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
for {
select {
case _, ok := <-ch:
if !ok {
return
}
case <-time.After(200 * time.Millisecond):
return
}
}
}
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3444,7 +3431,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3513,7 +3500,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3908,13 +3895,13 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
require.NoError(t, err, "unable to add peer1")
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)

Some files were not shown because too many files have changed in this diff Show More