diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index d1d2a8e50..a721cb516 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -31,7 +31,7 @@ jobs: while IFS= read -r dir; do echo "=== Checking $dir ===" # Search for problematic imports, excluding test files - RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) + RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true) if [ -n "$RESULTS" ]; then echo "❌ Found problematic dependencies:" echo "$RESULTS" @@ -88,7 +88,7 @@ jobs: IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath") # Check if any importer is NOT in management/signal/relay - BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1) + BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1) if [ -n "$BSD_IMPORTER" ]; then echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER" diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9e753ce73..62dfe9bce 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA skip: go.mod,go.sum,**/proxy/web/** golangci: strategy: diff --git a/.github/workflows/proto-version-check.yml b/.github/workflows/proto-version-check.yml new file mode 100644 index 000000000..ea300419d --- /dev/null +++ b/.github/workflows/proto-version-check.yml @@ -0,0 +1,62 @@ +name: Proto Version Check + +on: + pull_request: + paths: + - "**/*.pb.go" + +jobs: + check-proto-versions: + runs-on: ubuntu-latest + steps: + - name: Check for proto tool version changes + uses: actions/github-script@v7 + with: + script: | + const files = await github.paginate(github.rest.pulls.listFiles, { + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.issue.number, + per_page: 100, + }); + + const pbFiles = files.filter(f => f.filename.endsWith('.pb.go')); + const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename); + if (missingPatch.length > 0) { + core.setFailed( + `Cannot inspect patch data for:\n` + + missingPatch.map(f => `- ${f}`).join('\n') + + `\nThis can happen with very large PRs. Verify proto versions manually.` + ); + return; + } + const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/; + const violations = []; + + for (const file of pbFiles) { + const changed = file.patch + .split('\n') + .filter(line => versionPattern.test(line)); + if (changed.length > 0) { + violations.push({ + file: file.filename, + lines: changed, + }); + } + } + + if (violations.length > 0) { + const details = violations.map(v => + `${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}` + ).join('\n\n'); + + core.setFailed( + `Proto version strings changed in generated files.\n` + + `This usually means the wrong protoc or protoc-gen-go version was used.\n` + + `Regenerate with the matching tool versions.\n\n` + + details + ); + return; + } + + console.log('No proto version string changes detected'); diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 83444b541..c1ae01a98 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.1.1" + SIGN_PIPE_VER: "v0.1.4" GORELEASER_VER: "v2.14.3" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -114,7 +114,13 @@ jobs: retention-days: 30 release: - runs-on: ubuntu-latest-m + runs-on: ubuntu-24.04-8-core + outputs: + release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }} + linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }} + windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }} + macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }} + ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }} env: flags: "" steps: @@ -213,10 +219,13 @@ jobs: if: always() run: rm -f /tmp/gpg-rpm-signing-key.asc - name: Tag and push images (amd64 only) + id: tag_and_push_images if: | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || (github.event_name == 'push' && github.ref == 'refs/heads/main') run: | + set -euo pipefail + resolve_tags() { if [[ "${{ github.event_name }}" == "pull_request" ]]; then echo "pr-${{ github.event.pull_request.number }}" @@ -225,6 +234,17 @@ jobs: fi } + ghcr_package_url() { + local image="$1" package encoded_package + package="${image#ghcr.io/}" + package="${package#*/}" + package="${package%%:*}" + encoded_package="${package//\//%2F}" + echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}" + } + + image_refs=() + tag_and_push() { local src="$1" img_name tag dst img_name="${src%%:*}" @@ -233,35 +253,56 @@ jobs: echo "Tagging ${src} -> ${dst}" docker tag "$src" "$dst" docker push "$dst" + image_refs+=("$dst") done } - export -f tag_and_push resolve_tags + cat > /tmp/goreleaser-artifacts.json <<'JSON' + ${{ steps.goreleaser.outputs.artifacts }} + JSON - echo '${{ steps.goreleaser.outputs.artifacts }}' | \ - jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \ - grep '^ghcr.io/' | while read -r SRC; do - tag_and_push "$SRC" - done + mapfile -t src_images < <( + jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json + ) + + for src in "${src_images[@]}"; do + tag_and_push "$src" + done + + { + echo "images_markdown<> "$GITHUB_OUTPUT" - name: upload non tags for debug purposes + id: upload_release uses: actions/upload-artifact@v4 with: name: release path: dist/ retention-days: 7 - name: upload linux packages + id: upload_linux_packages uses: actions/upload-artifact@v4 with: name: linux-packages path: dist/netbird_linux** retention-days: 7 - name: upload windows packages + id: upload_windows_packages uses: actions/upload-artifact@v4 with: name: windows-packages path: dist/netbird_windows** retention-days: 7 - name: upload macos packages + id: upload_macos_packages uses: actions/upload-artifact@v4 with: name: macos-packages @@ -270,6 +311,8 @@ jobs: release_ui: runs-on: ubuntu-latest + outputs: + release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }} steps: - name: Parse semver string id: semver_parser @@ -360,6 +403,7 @@ jobs: if: always() run: rm -f /tmp/gpg-rpm-signing-key.asc - name: upload non tags for debug purposes + id: upload_release_ui uses: actions/upload-artifact@v4 with: name: release-ui @@ -368,6 +412,8 @@ jobs: release_ui_darwin: runs-on: macos-latest + outputs: + release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }} steps: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV @@ -402,15 +448,258 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: upload non tags for debug purposes + id: upload_release_ui_darwin uses: actions/upload-artifact@v4 with: name: release-ui-darwin path: dist/ retention-days: 3 - trigger_signer: + test_windows_installer: + name: "Windows Installer / Build Test" + runs-on: windows-2022 + needs: [release, release_ui] + strategy: + fail-fast: false + matrix: + include: + - arch: amd64 + wintun_arch: amd64 + - arch: arm64 + wintun_arch: arm64 + defaults: + run: + shell: powershell + env: + PackageWorkdir: netbird_windows_${{ matrix.arch }} + downloadPath: '${{ github.workspace }}\temp' + steps: + - name: Parse semver string + id: semver_parser + uses: booxmedialtd/ws-action-parse-semver@v1 + with: + input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }} + version_extractor_regex: '\/v(.*)$' + + - name: Checkout + uses: actions/checkout@v4 + + - name: Add 7-Zip to PATH + run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + + - name: Download release artifacts + uses: actions/download-artifact@v4 + with: + name: release + path: release + + - name: Download UI release artifacts + uses: actions/download-artifact@v4 + with: + name: release-ui + path: release-ui + + - name: Stage binaries into dist + run: | + $workdir = "dist\${{ env.PackageWorkdir }}" + New-Item -ItemType Directory -Force -Path $workdir | Out-Null + $client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1 + $ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1 + if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 } + if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 } + Write-Host "Client: $($client.FullName)" + Write-Host "UI: $($ui.FullName)" + tar -zvxf $client.FullName -C $workdir + tar -zvxf $ui.FullName -C $workdir + Get-ChildItem $workdir + + - name: Download wintun + uses: carlosperate/download-file-action@v2 + id: download-wintun + with: + file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip + file-name: wintun.zip + location: ${{ env.downloadPath }} + sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51' + + - name: Decompress wintun files + run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }} + + - name: Move wintun.dll into dist + run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\ + + - name: Download Mesa3D (amd64 only) + uses: carlosperate/download-file-action@v2 + id: download-mesa3d + if: matrix.arch == 'amd64' + with: + file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z + file-name: mesa3d.7z + location: ${{ env.downloadPath }} + sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9' + + - name: Extract Mesa3D driver (amd64 only) + if: matrix.arch == 'amd64' + run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z" + + - name: Move opengl32.dll into dist (amd64 only) + if: matrix.arch == 'amd64' + run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\ + + - name: Download EnVar plugin for NSIS + uses: carlosperate/download-file-action@v2 + with: + file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip + file-name: envar_plugin.zip + location: ${{ github.workspace }} + + - name: Extract EnVar plugin + run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip" + + - name: Download ShellExecAsUser plugin for NSIS (amd64 only) + uses: carlosperate/download-file-action@v2 + if: matrix.arch == 'amd64' + with: + file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z + file-name: ShellExecAsUser_amd64-Unicode.7z + location: ${{ github.workspace }} + + - name: Extract ShellExecAsUser plugin (amd64 only) + if: matrix.arch == 'amd64' + run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z" + + - name: Build NSIS installer + uses: joncloud/makensis-action@v3.3 + with: + additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins + script-file: client/installer.nsis + arguments: "/V4 /DARCH=${{ matrix.arch }}" + env: + APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }} + + - name: Rename NSIS installer + run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe + + - name: Install WiX + run: | + dotnet tool install --global wix --version 6.0.2 + wix extension add WixToolset.Util.wixext/6.0.2 + + - name: Build MSI installer + env: + NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}" + run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }} + + - name: Upload installer artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: windows-installer-test-${{ matrix.arch }} + path: | + netbird_installer_test_windows_${{ matrix.arch }}.exe + netbird_installer_test_windows_${{ matrix.arch }}.msi + retention-days: 3 + + comment_release_artifacts: + name: Comment release artifacts runs-on: ubuntu-latest needs: [release, release_ui, release_ui_darwin] + if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }} + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Create or update PR comment + uses: actions/github-script@v7 + env: + RELEASE_RESULT: ${{ needs.release.result }} + RELEASE_UI_RESULT: ${{ needs.release_ui.result }} + RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }} + RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }} + LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }} + WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }} + MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }} + RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }} + RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }} + GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }} + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const marker = ''; + const { owner, repo } = context.repo; + const issue_number = context.payload.pull_request.number; + const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`; + const shortSha = context.payload.pull_request.head.sha.slice(0, 7); + + const artifactCell = (url, result) => { + if (url) return `[Download](${url})`; + return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_'; + }; + + const artifacts = [ + ['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT], + ['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT], + ]; + + const artifactRows = artifacts + .map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`) + .join('\n'); + + const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._'; + + const body = [ + marker, + '## Release artifacts', + '', + `Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`, + '', + '| Artifact | Link |', + '| --- | --- |', + artifactRows, + '', + '### GHCR images (amd64)', + ghcrImages, + '', + '_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._', + ].join('\n'); + + const comments = await github.paginate(github.rest.issues.listComments, { + owner, + repo, + issue_number, + per_page: 100, + }); + + const previous = comments.find(comment => + comment.user?.type === 'Bot' && comment.body?.includes(marker) + ); + + if (previous) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: previous.id, + body, + }); + core.info(`Updated release artifacts comment ${previous.id}`); + } else { + const { data } = await github.rest.issues.createComment({ + owner, + repo, + issue_number, + body, + }); + core.info(`Created release artifacts comment ${data.id}`); + } + + trigger_signer: + runs-on: ubuntu-latest + needs: [release, release_ui, release_ui_darwin, test_windows_installer] if: startsWith(github.ref, 'refs/tags/') steps: - name: Trigger binaries sign pipelines diff --git a/.github/workflows/sync-tag.yml b/.github/workflows/sync-tag.yml index 1cc553b12..a75d9a9d5 100644 --- a/.github/workflows/sync-tag.yml +++ b/.github/workflows/sync-tag.yml @@ -9,6 +9,8 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} cancel-in-progress: true +# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short +# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref. jobs: trigger_sync_tag: runs-on: ubuntu-latest @@ -20,4 +22,30 @@ jobs: ref: main repo: ${{ secrets.UPSTREAM_REPO }} token: ${{ secrets.NC_GITHUB_TOKEN }} + inputs: '{ "tag": "${{ github.ref_name }}" }' + + trigger_android_bump: + runs-on: ubuntu-latest + if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-') + steps: + - name: Trigger android-client submodule bump + uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1 + with: + workflow: bump-netbird.yml + ref: main + repo: netbirdio/android-client + token: ${{ secrets.NC_GITHUB_TOKEN }} + inputs: '{ "tag": "${{ github.ref_name }}" }' + + trigger_ios_bump: + runs-on: ubuntu-latest + if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-') + steps: + - name: Trigger ios-client submodule bump + uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1 + with: + workflow: bump-netbird.yml + ref: main + repo: netbirdio/ios-client + token: ${{ secrets.NC_GITHUB_TOKEN }} inputs: '{ "tag": "${{ github.ref_name }}" }' \ No newline at end of file diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 65e63dfa8..5ea479148 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -154,6 +154,26 @@ builds: - -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}} mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird-idp-migrate + dir: tools/idp-migrate + env: + - CGO_ENABLED=1 + - >- + {{- if eq .Runtime.Goos "linux" }} + {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} + {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} + {{- end }} + binary: netbird-idp-migrate + goos: + - linux + goarch: + - amd64 + - arm64 + - arm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + universal_binaries: - id: netbird @@ -166,6 +186,10 @@ archives: - netbird-wasm name_template: "{{ .ProjectName }}_{{ .Version }}" format: binary + - id: netbird-idp-migrate + builds: + - netbird-idp-migrate + name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}" nfpms: - maintainer: Netbird diff --git a/CONTRIBUTOR_LICENSE_AGREEMENT.md b/CONTRIBUTOR_LICENSE_AGREEMENT.md index 1fdd072c9..b0a6ee218 100644 --- a/CONTRIBUTOR_LICENSE_AGREEMENT.md +++ b/CONTRIBUTOR_LICENSE_AGREEMENT.md @@ -1,7 +1,7 @@ ## Contributor License Agreement This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual -submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany, +submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany, referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions under which NetBird may utilize software contributions provided by the Contributor for inclusion in its software development projects. By submitting this Agreement, the Contributor confirms their acceptance diff --git a/Makefile b/Makefile index 43379e115..5d52b94fa 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint $(GOLANGCI_LINT): @echo "Installing golangci-lint..." @mkdir -p ./bin - @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest # Lint only changed files (fast, for pre-push) lint: $(GOLANGCI_LINT) diff --git a/client/Dockerfile b/client/Dockerfile index 64d5ba04f..53e4555ef 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -17,6 +17,7 @@ ENV \ NETBIRD_BIN="/usr/local/bin/netbird" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ + NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 69d00aaf2..706bf40de 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -23,6 +23,7 @@ ENV \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_DISABLE_DNS="true" \ + NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/android/client.go b/client/android/client.go index 3fc571559..37e17a363 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( "os" "slices" "sync" + "time" "golang.org/x/exp/maps" @@ -15,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -26,6 +28,7 @@ import ( "github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" + types "github.com/netbirdio/netbird/upload-server/types" ) // ConnectionListener export internal Listener for mobile @@ -68,7 +71,30 @@ type Client struct { uiVersion string networkChangeListener listener.NetworkChangeListener + stateMu sync.RWMutex connectClient *internal.ConnectClient + config *profilemanager.Config + cacheDir string +} + +func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + c.config = cfg + c.cacheDir = cacheDir + c.connectClient = cc +} + +func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) { + c.stateMu.RLock() + defer c.stateMu.RUnlock() + return c.config, c.cacheDir, c.connectClient +} + +func (c *Client) getConnectClient() *internal.ConnectClient { + c.stateMu.RLock() + defer c.stateMu.RUnlock() + return c.connectClient } // NewClient instantiate a new Client @@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid cfgFile := platformFiles.ConfigurationFilePath() stateFile := platformFiles.StateFilePath() + cacheDir := platformFiles.CacheDir() log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) @@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + c.setState(cfg, cacheDir, connectClient) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR cfgFile := platformFiles.ConfigurationFilePath() stateFile := platformFiles.StateFilePath() + cacheDir := platformFiles.CacheDir() log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) @@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + c.setState(cfg, cacheDir, connectClient) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir) } // Stop the internal client and free the resources @@ -173,11 +203,12 @@ func (c *Client) Stop() { } func (c *Client) RenewTun(fd int) error { - if c.connectClient == nil { + cc := c.getConnectClient() + if cc == nil { return fmt.Errorf("engine not running") } - e := c.connectClient.Engine() + e := cc.Engine() if e == nil { return fmt.Errorf("engine not initialized") } @@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error { return e.RenewTun(fd) } +// DebugBundle generates a debug bundle, uploads it, and returns the upload key. +// It works both with and without a running engine. +func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) { + cfg, cacheDir, cc := c.stateSnapshot() + + // If the engine hasn't been started, load config from disk + if cfg == nil { + var err error + cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ + ConfigPath: platformFiles.ConfigurationFilePath(), + }) + if err != nil { + return "", fmt.Errorf("load config: %w", err) + } + cacheDir = platformFiles.CacheDir() + } + + deps := debug.GeneratorDependencies{ + InternalConfig: cfg, + StatusRecorder: c.recorder, + TempDir: cacheDir, + } + + if cc != nil { + resp, err := cc.GetLatestSyncResponse() + if err != nil { + log.Warnf("get latest sync response: %v", err) + } + deps.SyncResponse = resp + + if e := cc.Engine(); e != nil { + if cm := e.GetClientMetrics(); cm != nil { + deps.ClientMetrics = cm + } + } + } + + bundleGenerator := debug.NewBundleGenerator( + deps, + debug.BundleConfig{ + Anonymize: anonymize, + IncludeSystemInfo: true, + }, + ) + + path, err := bundleGenerator.Generate() + if err != nil { + return "", fmt.Errorf("generate debug bundle: %w", err) + } + defer func() { + if err := os.Remove(path); err != nil { + log.Errorf("failed to remove debug bundle file: %v", err) + } + }() + + uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path) + if err != nil { + return "", fmt.Errorf("upload debug bundle: %w", err) + } + + log.Infof("debug bundle uploaded with key %s", key) + return key, nil +} + // SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) @@ -205,7 +303,7 @@ func (c *Client) PeersList() *PeerInfoArray { pi := PeerInfo{ p.IP, p.FQDN, - p.ConnStatus.String(), + int(p.ConnStatus), PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi @@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray { } func (c *Client) Networks() *NetworkArray { - if c.connectClient == nil { + cc := c.getConnectClient() + if cc == nil { log.Error("not connected") return nil } - engine := c.connectClient.Engine() + engine := cc.Engine() if engine == nil { log.Error("could not get engine") return nil @@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error { } func (c *Client) getRouteManager() (routemanager.Manager, error) { - client := c.connectClient + client := c.getConnectClient() if client == nil { return nil, fmt.Errorf("not connected") } diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index b03947da1..4ec22f3ab 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -2,11 +2,20 @@ package android +import "github.com/netbirdio/netbird/client/internal/peer" + +// Connection status constants exported via gomobile. +const ( + ConnStatusIdle = int(peer.StatusIdle) + ConnStatusConnecting = int(peer.StatusConnecting) + ConnStatusConnected = int(peer.StatusConnected) +) + // PeerInfo describe information about the peers. It designed for the UI usage type PeerInfo struct { IP string FQDN string - ConnStatus string // Todo replace to enum + ConnStatus int Routes PeerRoutes } diff --git a/client/android/platform_files.go b/client/android/platform_files.go index f0c369750..3be40c0bd 100644 --- a/client/android/platform_files.go +++ b/client/android/platform_files.go @@ -7,4 +7,5 @@ package android type PlatformFiles interface { ConfigurationFilePath() string StateFilePath() string + CacheDir() string } diff --git a/client/cmd/capture.go b/client/cmd/capture.go new file mode 100644 index 000000000..95caaa5cd --- /dev/null +++ b/client/cmd/capture.go @@ -0,0 +1,196 @@ +package cmd + +import ( + "context" + "fmt" + "io" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/hashicorp/go-multierror" + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +var captureCmd = &cobra.Command{ + Use: "capture", + Short: "Capture packets on the WireGuard interface", + Long: `Captures decrypted packets flowing through the WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Requires --enable-capture to be set at service install or reconfigure time. + +Examples: + netbird debug capture + netbird debug capture host 100.64.0.1 and port 443 + netbird debug capture tcp + netbird debug capture icmp + netbird debug capture src host 10.0.0.1 and dst port 80 + netbird debug capture -o capture.pcap + netbird debug capture --pcap | tshark -r - + netbird debug capture --pcap | tcpdump -r - -n`, + Args: cobra.ArbitraryArgs, + RunE: runCapture, +} + +func init() { + debugCmd.AddCommand(captureCmd) + + captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length") + captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)") + captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)") + captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)") + captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") +} + +func runCapture(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + cmd.PrintErrf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + + req, err := buildCaptureRequest(cmd, args) + if err != nil { + return err + } + + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + stream, err := client.StartCapture(ctx, req) + if err != nil { + return handleCaptureError(err) + } + + // First Recv is the empty acceptance message from the server. If the + // device is unavailable (kernel WG, not connected, capture disabled), + // the server returns an error instead. + if _, err := stream.Recv(); err != nil { + return handleCaptureError(err) + } + + out, cleanup, err := captureOutput(cmd) + if err != nil { + return err + } + + if req.TextOutput { + cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n") + } else { + cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n") + } + + streamErr := streamCapture(ctx, cmd, stream, out) + cleanupErr := cleanup() + if streamErr != nil { + return streamErr + } + return cleanupErr +} + +func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) { + req := &proto.StartCaptureRequest{} + + if len(args) > 0 { + expr := strings.Join(args, " ") + if _, err := capture.ParseFilter(expr); err != nil { + return nil, fmt.Errorf("invalid filter: %w", err) + } + req.FilterExpr = expr + } + + if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 { + req.SnapLen = snap + } + if d, _ := cmd.Flags().GetDuration("duration"); d != 0 { + if d < 0 { + return nil, fmt.Errorf("duration must not be negative") + } + req.Duration = durationpb.New(d) + } + req.Verbose, _ = cmd.Flags().GetBool("verbose") + req.Ascii, _ = cmd.Flags().GetBool("ascii") + + outPath, _ := cmd.Flags().GetString("output") + forcePcap, _ := cmd.Flags().GetBool("pcap") + req.TextOutput = !forcePcap && outPath == "" + + return req, nil +} + +func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error { + for { + pkt, err := stream.Recv() + if err != nil { + if ctx.Err() != nil { + cmd.PrintErrf("\nCapture stopped.\n") + return nil //nolint:nilerr // user interrupted + } + if err == io.EOF { + cmd.PrintErrf("\nCapture finished.\n") + return nil + } + return handleCaptureError(err) + } + if _, err := out.Write(pkt.GetData()); err != nil { + return fmt.Errorf("write output: %w", err) + } + } +} + +// captureOutput returns the writer for capture data and a cleanup function +// that finalizes the file. Errors from the cleanup must be propagated. +func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) { + outPath, _ := cmd.Flags().GetString("output") + if outPath == "" { + return os.Stdout, func() error { return nil }, nil + } + + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() error { + var merr *multierror.Error + if err := f.Close(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err)) + } + fi, statErr := os.Stat(tmpPath) + if statErr != nil || fi.Size() == 0 { + if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) { + merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr)) + } + return nberrors.FormatErrorOrNil(merr) + } + if err := os.Rename(tmpPath, outPath); err != nil { + merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err)) + return nberrors.FormatErrorOrNil(merr) + } + cmd.PrintErrf("Wrote %s\n", outPath) + return nberrors.FormatErrorOrNil(merr) + }, nil +} + +func handleCaptureError(err error) error { + if s, ok := status.FromError(err); ok { + return fmt.Errorf("%s", s.Message()) + } + return err +} diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 0e2717756..2a8cdc887 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/debug" @@ -199,9 +200,11 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Log level set to trace.") } + needsRestoreUp := false if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message()) } else { + needsRestoreUp = !stateWasDown cmd.Println("netbird down") } @@ -217,6 +220,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) } else { + needsRestoreUp = false cmd.Println("netbird up") } @@ -236,11 +240,50 @@ func runForDuration(cmd *cobra.Command, args []string) error { }() } + captureStarted := false + if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture { + captureTimeout := duration + 30*time.Second + const maxBundleCapture = 10 * time.Minute + if captureTimeout > maxBundleCapture { + captureTimeout = maxBundleCapture + } + _, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(captureTimeout), + }) + if err != nil { + cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message()) + } else { + captureStarted = true + cmd.Println("Packet capture started.") + // Safety: always stop on exit, even if the normal stop below runs too. + defer func() { + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } + } + }() + } + } + if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { return waitErr } cmd.Println("\nDuration completed") + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } else { + captureStarted = false + cmd.Println("Packet capture stopped.") + } + } + if cpuProfilingStarted { if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) @@ -264,6 +307,14 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + if needsRestoreUp { + if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { + cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird up (restored)") + } + } + if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message()) @@ -405,4 +456,5 @@ func init() { forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") + forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle") } diff --git a/client/cmd/expose.go b/client/cmd/expose.go index 1334617d8..c48a6adac 100644 --- a/client/cmd/expose.go +++ b/client/cmd/expose.go @@ -14,7 +14,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/internal/expose" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/util" ) @@ -200,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error { stream, err := client.ExposeService(ctx, req) if err != nil { - return fmt.Errorf("expose service: %w", err) + return fmt.Errorf("expose service: %v", status.Convert(err).Message()) } if err := handleExposeReady(cmd, stream, port); err != nil { @@ -211,26 +213,31 @@ func exposeFn(cmd *cobra.Command, args []string) error { } func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) { - switch strings.ToLower(exposeProtocol) { - case "http": + p, err := expose.ParseProtocolType(exposeProtocol) + if err != nil { + return 0, fmt.Errorf("invalid protocol: %w", err) + } + + switch p { + case expose.ProtocolHTTP: return proto.ExposeProtocol_EXPOSE_HTTP, nil - case "https": + case expose.ProtocolHTTPS: return proto.ExposeProtocol_EXPOSE_HTTPS, nil - case "tcp": + case expose.ProtocolTCP: return proto.ExposeProtocol_EXPOSE_TCP, nil - case "udp": + case expose.ProtocolUDP: return proto.ExposeProtocol_EXPOSE_UDP, nil - case "tls": + case expose.ProtocolTLS: return proto.ExposeProtocol_EXPOSE_TLS, nil default: - return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) + return 0, fmt.Errorf("unhandled protocol type: %d", p) } } func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error { event, err := stream.Recv() if err != nil { - return fmt.Errorf("receive expose event: %w", err) + return fmt.Errorf("receive expose event: %v", status.Convert(err).Message()) } ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready) diff --git a/client/cmd/root.go b/client/cmd/root.go index aa5b98dfd..29d4328a1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -75,6 +75,8 @@ var ( mtu uint16 profilesDisabled bool updateSettingsDisabled bool + captureEnabled bool + networksDisabled bool rootCmd = &cobra.Command{ Use: "netbird", diff --git a/client/cmd/service.go b/client/cmd/service.go index e55465875..56d8a8726 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -41,13 +41,17 @@ func init() { defaultServiceName = "Netbird" } - serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd) + serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") + serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture") + serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") serviceEnvDesc := `Sets extra environment variables for the service. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` + + `New keys are merged with previously saved env vars; existing keys are overwritten. ` + + `Use --service-env "" to clear all saved env vars. ` + `E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value` installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 5fe318ddf..88121c067 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index f6828d96a..2d45fa063 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -59,6 +59,14 @@ func buildServiceArguments() []string { args = append(args, "--disable-update-settings") } + if captureEnabled { + args = append(args, "--enable-capture") + } + + if networksDisabled { + args = append(args, "--disable-networks") + } + return args } @@ -119,6 +127,10 @@ var installCmd = &cobra.Command{ return err } + if err := loadAndApplyServiceParams(cmd); err != nil { + cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err) + } + svcConfig, err := createServiceConfigForInstall() if err != nil { return err @@ -136,6 +148,10 @@ var installCmd = &cobra.Command{ return fmt.Errorf("install service: %w", err) } + if err := saveServiceParams(currentServiceParams()); err != nil { + cmd.PrintErrf("Warning: failed to save service params: %v\n", err) + } + cmd.Println("NetBird service has been installed") return nil }, @@ -187,6 +203,10 @@ This command will temporarily stop the service, update its configuration, and re return err } + if err := loadAndApplyServiceParams(cmd); err != nil { + cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err) + } + wasRunning, err := isServiceRunning() if err != nil && !errors.Is(err, ErrGetServiceStatus) { return fmt.Errorf("check service status: %w", err) @@ -222,6 +242,10 @@ This command will temporarily stop the service, update its configuration, and re return fmt.Errorf("install service with new config: %w", err) } + if err := saveServiceParams(currentServiceParams()); err != nil { + cmd.PrintErrf("Warning: failed to save service params: %v\n", err) + } + if wasRunning { cmd.Println("Starting NetBird service...") if err := s.Start(); err != nil { diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go new file mode 100644 index 000000000..192e0ac60 --- /dev/null +++ b/client/cmd/service_params.go @@ -0,0 +1,224 @@ +//go:build !ios && !android + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "os" + "path/filepath" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/configs" + "github.com/netbirdio/netbird/util" +) + +const serviceParamsFile = "service.json" + +// serviceParams holds install-time service parameters that persist across +// uninstall/reinstall cycles. Saved to /service.json. +type serviceParams struct { + LogLevel string `json:"log_level"` + DaemonAddr string `json:"daemon_addr"` + ManagementURL string `json:"management_url,omitempty"` + ConfigPath string `json:"config_path,omitempty"` + LogFiles []string `json:"log_files,omitempty"` + DisableProfiles bool `json:"disable_profiles,omitempty"` + DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` + EnableCapture bool `json:"enable_capture,omitempty"` + DisableNetworks bool `json:"disable_networks,omitempty"` + ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` +} + +// serviceParamsPath returns the path to the service params file. +func serviceParamsPath() string { + return filepath.Join(configs.StateDir, serviceParamsFile) +} + +// loadServiceParams reads saved service parameters from disk. +// Returns nil with no error if the file does not exist. +func loadServiceParams() (*serviceParams, error) { + path := serviceParamsPath() + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil //nolint:nilnil + } + return nil, fmt.Errorf("read service params %s: %w", path, err) + } + + var params serviceParams + if err := json.Unmarshal(data, ¶ms); err != nil { + return nil, fmt.Errorf("parse service params %s: %w", path, err) + } + + return ¶ms, nil +} + +// saveServiceParams writes current service parameters to disk atomically +// with restricted permissions. +func saveServiceParams(params *serviceParams) error { + path := serviceParamsPath() + if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil { + return fmt.Errorf("save service params: %w", err) + } + return nil +} + +// currentServiceParams captures the current state of all package-level +// variables into a serviceParams struct. +func currentServiceParams() *serviceParams { + params := &serviceParams{ + LogLevel: logLevel, + DaemonAddr: daemonAddr, + ManagementURL: managementURL, + ConfigPath: configPath, + LogFiles: logFiles, + DisableProfiles: profilesDisabled, + DisableUpdateSettings: updateSettingsDisabled, + EnableCapture: captureEnabled, + DisableNetworks: networksDisabled, + } + + if len(serviceEnvVars) > 0 { + parsed, err := parseServiceEnvVars(serviceEnvVars) + if err == nil { + params.ServiceEnvVars = parsed + } + } + + return params +} + +// loadAndApplyServiceParams loads saved params from disk and applies them +// to any flags that were not explicitly set. +func loadAndApplyServiceParams(cmd *cobra.Command) error { + params, err := loadServiceParams() + if err != nil { + return err + } + applyServiceParams(cmd, params) + return nil +} + +// applyServiceParams merges saved parameters into package-level variables +// for any flag that was not explicitly set by the user (via CLI or env var). +// Flags that were Changed() are left untouched. +func applyServiceParams(cmd *cobra.Command, params *serviceParams) { + if params == nil { + return + } + + // For fields with non-empty defaults (log-level, daemon-addr), keep the + // != "" guard so that an older service.json missing the field doesn't + // clobber the default with an empty string. + if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" { + logLevel = params.LogLevel + } + + if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" { + daemonAddr = params.DaemonAddr + } + + // For optional fields where empty means "use default", always apply so + // that an explicit clear (--management-url "") persists across reinstalls. + if !rootCmd.PersistentFlags().Changed("management-url") { + managementURL = params.ManagementURL + } + + if !rootCmd.PersistentFlags().Changed("config") { + configPath = params.ConfigPath + } + + if !rootCmd.PersistentFlags().Changed("log-file") { + logFiles = params.LogFiles + } + + if !serviceCmd.PersistentFlags().Changed("disable-profiles") { + profilesDisabled = params.DisableProfiles + } + + if !serviceCmd.PersistentFlags().Changed("disable-update-settings") { + updateSettingsDisabled = params.DisableUpdateSettings + } + + if !serviceCmd.PersistentFlags().Changed("enable-capture") { + captureEnabled = params.EnableCapture + } + + if !serviceCmd.PersistentFlags().Changed("disable-networks") { + networksDisabled = params.DisableNetworks + } + + applyServiceEnvParams(cmd, params) +} + +// applyServiceEnvParams merges saved service environment variables. +// If --service-env was explicitly set with values, explicit values win on key +// conflict but saved keys not in the explicit set are carried over. +// If --service-env was explicitly set to empty, all saved env vars are cleared. +// If --service-env was not set, saved env vars are used entirely. +func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) { + if !cmd.Flags().Changed("service-env") { + if len(params.ServiceEnvVars) > 0 { + // No explicit env vars: rebuild serviceEnvVars from saved params. + serviceEnvVars = envMapToSlice(params.ServiceEnvVars) + } + return + } + + // Flag was explicitly set: parse what the user provided. + explicit, err := parseServiceEnvVars(serviceEnvVars) + if err != nil { + cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err) + return + } + + // If the user passed an empty value (e.g. --service-env ""), clear all + // saved env vars rather than merging. + if len(explicit) == 0 { + serviceEnvVars = nil + return + } + + if len(params.ServiceEnvVars) == 0 { + return + } + + // Merge saved values underneath explicit ones. + merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit)) + maps.Copy(merged, params.ServiceEnvVars) + maps.Copy(merged, explicit) // explicit wins on conflict + serviceEnvVars = envMapToSlice(merged) +} + +var resetParamsCmd = &cobra.Command{ + Use: "reset-params", + Short: "Remove saved service install parameters", + Long: "Removes the saved service.json file so the next install uses default parameters.", + RunE: func(cmd *cobra.Command, args []string) error { + path := serviceParamsPath() + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + cmd.Println("No saved service parameters found") + return nil + } + return fmt.Errorf("remove service params: %w", err) + } + cmd.Printf("Removed saved service parameters (%s)\n", path) + return nil + }, +} + +// envMapToSlice converts a map of env vars to a KEY=VALUE slice. +func envMapToSlice(m map[string]string) []string { + s := make([]string, 0, len(m)) + for k, v := range m { + s = append(s, k+"="+v) + } + return s +} diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go new file mode 100644 index 000000000..f338c12f4 --- /dev/null +++ b/client/cmd/service_params_test.go @@ -0,0 +1,560 @@ +//go:build !ios && !android + +package cmd + +import ( + "encoding/json" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/configs" +) + +func TestServiceParamsPath(t *testing.T) { + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + + configs.StateDir = "/var/lib/netbird" + assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath()) + + configs.StateDir = "/custom/state" + assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath()) +} + +func TestSaveAndLoadServiceParams(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + params := &serviceParams{ + LogLevel: "debug", + DaemonAddr: "unix:///var/run/netbird.sock", + ManagementURL: "https://my.server.com", + ConfigPath: "/etc/netbird/config.json", + LogFiles: []string{"/var/log/netbird/client.log", "console"}, + DisableProfiles: true, + DisableUpdateSettings: false, + ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"}, + } + + err := saveServiceParams(params) + require.NoError(t, err) + + // Verify the file exists and is valid JSON. + data, err := os.ReadFile(filepath.Join(tmpDir, "service.json")) + require.NoError(t, err) + assert.True(t, json.Valid(data)) + + loaded, err := loadServiceParams() + require.NoError(t, err) + require.NotNil(t, loaded) + + assert.Equal(t, params.LogLevel, loaded.LogLevel) + assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr) + assert.Equal(t, params.ManagementURL, loaded.ManagementURL) + assert.Equal(t, params.ConfigPath, loaded.ConfigPath) + assert.Equal(t, params.LogFiles, loaded.LogFiles) + assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles) + assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings) + assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars) +} + +func TestLoadServiceParams_FileNotExists(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + params, err := loadServiceParams() + assert.NoError(t, err) + assert.Nil(t, params) +} + +func TestLoadServiceParams_InvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600) + require.NoError(t, err) + + params, err := loadServiceParams() + assert.Error(t, err) + assert.Nil(t, params) +} + +func TestCurrentServiceParams(t *testing.T) { + origLogLevel := logLevel + origDaemonAddr := daemonAddr + origManagementURL := managementURL + origConfigPath := configPath + origLogFiles := logFiles + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { + logLevel = origLogLevel + daemonAddr = origDaemonAddr + managementURL = origManagementURL + configPath = origConfigPath + logFiles = origLogFiles + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + serviceEnvVars = origServiceEnvVars + }) + + logLevel = "trace" + daemonAddr = "tcp://127.0.0.1:9999" + managementURL = "https://mgmt.example.com" + configPath = "/tmp/test-config.json" + logFiles = []string{"/tmp/test.log"} + profilesDisabled = true + updateSettingsDisabled = true + serviceEnvVars = []string{"FOO=bar", "BAZ=qux"} + + params := currentServiceParams() + + assert.Equal(t, "trace", params.LogLevel) + assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr) + assert.Equal(t, "https://mgmt.example.com", params.ManagementURL) + assert.Equal(t, "/tmp/test-config.json", params.ConfigPath) + assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles) + assert.True(t, params.DisableProfiles) + assert.True(t, params.DisableUpdateSettings) + assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars) +} + +func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) { + origLogLevel := logLevel + origDaemonAddr := daemonAddr + origManagementURL := managementURL + origConfigPath := configPath + origLogFiles := logFiles + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { + logLevel = origLogLevel + daemonAddr = origDaemonAddr + managementURL = origManagementURL + configPath = origConfigPath + logFiles = origLogFiles + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + serviceEnvVars = origServiceEnvVars + }) + + // Reset all flags to defaults. + logLevel = "info" + daemonAddr = "unix:///var/run/netbird.sock" + managementURL = "" + configPath = "/etc/netbird/config.json" + logFiles = []string{"/var/log/netbird/client.log"} + profilesDisabled = false + updateSettingsDisabled = false + serviceEnvVars = nil + + // Reset Changed state on all relevant flags. + rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + // Simulate user explicitly setting --log-level via CLI. + logLevel = "warn" + require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn")) + + saved := &serviceParams{ + LogLevel: "debug", + DaemonAddr: "tcp://127.0.0.1:5555", + ManagementURL: "https://saved.example.com", + ConfigPath: "/saved/config.json", + LogFiles: []string{"/saved/client.log"}, + DisableProfiles: true, + DisableUpdateSettings: true, + ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"}, + } + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + // log-level was Changed, so it should keep "warn", not use saved "debug". + assert.Equal(t, "warn", logLevel) + + // All other fields were not Changed, so they should use saved values. + assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr) + assert.Equal(t, "https://saved.example.com", managementURL) + assert.Equal(t, "/saved/config.json", configPath) + assert.Equal(t, []string{"/saved/client.log"}, logFiles) + assert.True(t, profilesDisabled) + assert.True(t, updateSettingsDisabled) + assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars) +} + +func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) { + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + t.Cleanup(func() { + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + }) + + // Simulate current state where booleans are true (e.g. set by previous install). + profilesDisabled = true + updateSettingsDisabled = true + + // Reset Changed state so flags appear unset. + serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + // Saved params have both as false. + saved := &serviceParams{ + DisableProfiles: false, + DisableUpdateSettings: false, + } + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + assert.False(t, profilesDisabled, "saved false should override current true") + assert.False(t, updateSettingsDisabled, "saved false should override current true") +} + +func TestApplyServiceParams_ClearManagementURL(t *testing.T) { + origManagementURL := managementURL + t.Cleanup(func() { managementURL = origManagementURL }) + + managementURL = "https://leftover.example.com" + + // Simulate saved params where management URL was explicitly cleared. + saved := &serviceParams{ + LogLevel: "info", + DaemonAddr: "unix:///var/run/netbird.sock", + // ManagementURL intentionally empty: was cleared with --management-url "". + } + + rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value") +} + +func TestApplyServiceParams_NilParams(t *testing.T) { + origLogLevel := logLevel + t.Cleanup(func() { logLevel = origLogLevel }) + + logLevel = "info" + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + + // Should be a no-op. + applyServiceParams(cmd, nil) + assert.Equal(t, "info", logLevel) +} + +func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Set up a command with --service-env marked as Changed. + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit")) + + serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"} + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{ + "SAVED": "val", + "OVERLAP": "saved", + }, + } + + applyServiceEnvParams(cmd, saved) + + // Parse result for easier assertion. + result, err := parseServiceEnvVars(serviceEnvVars) + require.NoError(t, err) + + assert.Equal(t, "yes", result["EXPLICIT"]) + assert.Equal(t, "val", result["SAVED"]) + // Explicit wins on conflict. + assert.Equal(t, "explicit", result["OVERLAP"]) +} + +func TestApplyServiceEnvParams_NotChanged(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + serviceEnvVars = nil + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{"FROM_SAVED": "val"}, + } + + applyServiceEnvParams(cmd, saved) + + result, err := parseServiceEnvVars(serviceEnvVars) + require.NoError(t, err) + assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result) +} + +func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + require.NoError(t, cmd.Flags().Set("service-env", "")) + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"}, + } + + applyServiceEnvParams(cmd, saved) + + assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars") +} + +func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + params := currentServiceParams() + + // After parsing, the empty string is skipped, resulting in an empty map. + // The map should still be set (not nil) so it overwrites saved values. + assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil") + assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string") +} + +// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are +// referenced in both currentServiceParams() and applyServiceParams(). If a new field is +// added to serviceParams but not wired into these functions, this test fails. +func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "service_params.go", nil, 0) + require.NoError(t, err) + + // Collect all JSON field names from the serviceParams struct. + structFields := extractStructJSONFields(t, file, "serviceParams") + require.NotEmpty(t, structFields, "failed to find serviceParams struct fields") + + // Collect field names referenced in currentServiceParams and applyServiceParams. + currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields) + applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields) + // applyServiceEnvParams handles ServiceEnvVars indirectly. + applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields) + for k, v := range applyEnvFields { + applyFields[k] = v + } + + for _, field := range structFields { + assert.Contains(t, currentFields, field, + "serviceParams field %q is not captured in currentServiceParams()", field) + assert.Contains(t, applyFields, field, + "serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field) + } +} + +// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references +// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because +// it flows through newSVCConfig() EnvVars, not CLI args. +func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "service_params.go", nil, 0) + require.NoError(t, err) + + structFields := extractStructJSONFields(t, file, "serviceParams") + require.NotEmpty(t, structFields) + + installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0) + require.NoError(t, err) + + // Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig). + fieldsNotInArgs := map[string]bool{ + "ServiceEnvVars": true, + } + + buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments") + + // Forward: every struct field must appear in buildServiceArguments. + for _, field := range structFields { + if fieldsNotInArgs[field] { + continue + } + globalVar := fieldToGlobalVar(field) + assert.Contains(t, buildFields, globalVar, + "serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar) + } + + // Reverse: every service-related global used in buildServiceArguments must + // have a corresponding serviceParams field. This catches a developer adding + // a new flag to buildServiceArguments without adding it to the struct. + globalToField := make(map[string]string, len(structFields)) + for _, field := range structFields { + globalToField[fieldToGlobalVar(field)] = field + } + // Identifiers in buildServiceArguments that are not service params + // (builtins, boilerplate, loop variables). + nonParamGlobals := map[string]bool{ + "args": true, "append": true, "string": true, "_": true, + "logFile": true, // range variable over logFiles + } + for ref := range buildFields { + if nonParamGlobals[ref] { + continue + } + _, inStruct := globalToField[ref] + assert.True(t, inStruct, + "buildServiceArguments() references global %q which has no corresponding serviceParams field", ref) + } +} + +// extractStructJSONFields returns field names from a named struct type. +func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string { + t.Helper() + var fields []string + ast.Inspect(file, func(n ast.Node) bool { + ts, ok := n.(*ast.TypeSpec) + if !ok || ts.Name.Name != structName { + return true + } + st, ok := ts.Type.(*ast.StructType) + if !ok { + return false + } + for _, f := range st.Fields.List { + if len(f.Names) > 0 { + fields = append(fields, f.Names[0].Name) + } + } + return false + }) + return fields +} + +// extractFuncFieldRefs returns which of the given field names appear inside the +// named function, either as selector expressions (params.FieldName) or as +// composite literal keys (&serviceParams{FieldName: ...}). +func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool { + t.Helper() + fieldSet := make(map[string]bool, len(fields)) + for _, f := range fields { + fieldSet[f] = true + } + + found := make(map[string]bool) + fn := findFuncDecl(file, funcName) + require.NotNil(t, fn, "function %s not found", funcName) + + ast.Inspect(fn.Body, func(n ast.Node) bool { + switch v := n.(type) { + case *ast.SelectorExpr: + if fieldSet[v.Sel.Name] { + found[v.Sel.Name] = true + } + case *ast.KeyValueExpr: + if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] { + found[ident.Name] = true + } + } + return true + }) + return found +} + +// extractFuncGlobalRefs returns all identifier names referenced in the named function body. +func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool { + t.Helper() + fn := findFuncDecl(file, funcName) + require.NotNil(t, fn, "function %s not found", funcName) + + refs := make(map[string]bool) + ast.Inspect(fn.Body, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + refs[ident.Name] = true + } + return true + }) + return refs +} + +func findFuncDecl(file *ast.File, name string) *ast.FuncDecl { + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if ok && fn.Name.Name == name { + return fn + } + } + return nil +} + +// fieldToGlobalVar maps serviceParams field names to the package-level variable +// names used in buildServiceArguments and applyServiceParams. +func fieldToGlobalVar(field string) string { + m := map[string]string{ + "LogLevel": "logLevel", + "DaemonAddr": "daemonAddr", + "ManagementURL": "managementURL", + "ConfigPath": "configPath", + "LogFiles": "logFiles", + "DisableProfiles": "profilesDisabled", + "DisableUpdateSettings": "updateSettingsDisabled", + "EnableCapture": "captureEnabled", + "DisableNetworks": "networksDisabled", + "ServiceEnvVars": "serviceEnvVars", + } + if v, ok := m[field]; ok { + return v + } + // Default: lowercase first letter. + return strings.ToLower(field[:1]) + field[1:] +} + +func TestEnvMapToSlice(t *testing.T) { + m := map[string]string{"A": "1", "B": "2"} + s := envMapToSlice(m) + assert.Len(t, s, 2) + assert.Contains(t, s, "A=1") + assert.Contains(t, s, "B=2") +} + +func TestEnvMapToSlice_Empty(t *testing.T) { + s := envMapToSlice(map[string]string{}) + assert.Empty(t, s) +} diff --git a/client/cmd/service_test.go b/client/cmd/service_test.go index 6d75ca524..ce6f71550 100644 --- a/client/cmd/service_test.go +++ b/client/cmd/service_test.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "os" + "os/signal" "runtime" + "syscall" "testing" "time" @@ -13,6 +15,22 @@ import ( "github.com/stretchr/testify/require" ) +// TestMain intercepts when this test binary is run as a daemon subprocess. +// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with +// "service run ..." arguments. Since the test binary can't handle cobra CLI +// args, it exits immediately, causing daemon -r to respawn rapidly until +// hitting the rate limit and exiting. This makes service restart unreliable. +// Blocking here keeps the subprocess alive until the init system sends SIGTERM. +func TestMain(m *testing.M) { + if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM, os.Interrupt) + <-sig + return + } + os.Exit(m.Run()) +} + const ( serviceStartTimeout = 10 * time.Second serviceStopTimeout = 5 * time.Second @@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) { logLevel = "info" daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir) + // Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run. + t.Cleanup(func() { + cfg, err := newSVCConfig() + if err != nil { + t.Errorf("cleanup: create service config: %v", err) + return + } + ctxSvc, cancel := context.WithCancel(context.Background()) + defer cancel() + s, err := newSVC(newProgram(ctxSvc, cancel), cfg) + if err != nil { + t.Errorf("cleanup: create service: %v", err) + return + } + + // If the subtests already cleaned up, there's nothing to do. + if _, err := s.Status(); err != nil { + return + } + + if err := s.Stop(); err != nil { + t.Errorf("cleanup: stop service: %v", err) + } + if err := s.Uninstall(); err != nil { + t.Errorf("cleanup: uninstall service: %v", err) + } + }) + ctx := context.Background() t.Run("Install", func(t *testing.T) { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 4bda33e65..c24965e8d 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -13,6 +13,8 @@ import ( "github.com/netbirdio/management-integrations/integrations" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" @@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp jobManager := job.NewJobManager(nil, store, peersmanager) - iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) } @@ -127,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) } @@ -152,7 +160,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "", "", false, false) + "", "", false, false, false, false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/embed/capture.go b/client/embed/capture.go new file mode 100644 index 000000000..30f9b496f --- /dev/null +++ b/client/embed/capture.go @@ -0,0 +1,65 @@ +package embed + +import ( + "io" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/util/capture" +) + +// CaptureOptions configures a packet capture session. +type CaptureOptions struct { + // Output receives pcap-formatted data. Nil disables pcap output. + Output io.Writer + // TextOutput receives human-readable packet summaries. Nil disables text output. + TextOutput io.Writer + // Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443"). + // Empty captures all packets. + Filter string + // Verbose adds seq/ack, TTL, window, and total length to text output. + Verbose bool + // ASCII dumps transport payload as printable ASCII after each packet line. + ASCII bool +} + +// CaptureStats reports capture session counters. +type CaptureStats struct { + Packets int64 + Bytes int64 + Dropped int64 +} + +// CaptureSession represents an active packet capture. Call Stop to end the +// capture and flush buffered packets. +type CaptureSession struct { + sess *capture.Session + engine *internal.Engine +} + +// Stop ends the capture, flushes remaining packets, and detaches from the device. +// Safe to call multiple times. +func (cs *CaptureSession) Stop() { + if cs.engine != nil { + _ = cs.engine.SetCapture(nil) + cs.engine = nil + } + if cs.sess != nil { + cs.sess.Stop() + } +} + +// Stats returns current capture counters. +func (cs *CaptureSession) Stats() CaptureStats { + s := cs.sess.Stats() + return CaptureStats{ + Packets: s.Packets, + Bytes: s.Bytes, + Dropped: s.Dropped, + } +} + +// Done returns a channel that is closed when the capture's writer goroutine +// has fully exited and all buffered packets have been flushed. +func (cs *CaptureSession) Done() <-chan struct{} { + return cs.sess.Done() +} diff --git a/client/embed/embed.go b/client/embed/embed.go index 9fa797f18..baa1d94d6 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/util/capture" ) var ( @@ -33,14 +34,14 @@ var ( ErrConfigNotInitialized = errors.New("config not initialized") ) -// PeerConnStatus is a peer's connection status. -type PeerConnStatus = peer.ConnStatus - const ( // PeerStatusConnected indicates the peer is in connected state. PeerStatusConnected = peer.StatusConnected ) +// PeerConnStatus is a peer's connection status. +type PeerConnStatus = peer.ConnStatus + // Client manages a netbird embedded client instance. type Client struct { deviceName string @@ -65,7 +66,7 @@ type Options struct { PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string - // PreSharedKey is the pre-shared key for the WireGuard interface + // PreSharedKey is the pre-shared key for the tunnel interface PreSharedKey string // LogOutput is the output destination for logs (defaults to os.Stderr if nil) LogOutput io.Writer @@ -81,9 +82,9 @@ type Options struct { DisableClientRoutes bool // BlockInbound blocks all inbound connections from peers BlockInbound bool - // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. + // WireguardPort is the port for the tunnel interface. Use 0 for a random port. WireguardPort *int - // MTU is the MTU for the WireGuard interface. + // MTU is the MTU for the tunnel interface. // Valid values are in the range 576..8192 bytes. // If non-nil, this value overrides any value stored in the config file. // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. @@ -375,6 +376,32 @@ func (c *Client) NewHTTPClient() *http.Client { } } +// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL. +// It returns an ExposeSession. Call Wait on the session to keep it alive. +func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + mgr := engine.GetExposeManager() + if mgr == nil { + return nil, fmt.Errorf("expose manager not available") + } + + resp, err := mgr.Expose(ctx, req) + if err != nil { + return nil, fmt.Errorf("expose: %w", err) + } + + return &ExposeSession{ + Domain: resp.Domain, + ServiceName: resp.ServiceName, + ServiceURL: resp.ServiceURL, + mgr: mgr, + }, nil +} + // Status returns the current status of the client. func (c *Client) Status() (peer.FullStatus, error) { c.mu.Lock() @@ -443,6 +470,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { return sshcommon.VerifyHostKey(storedKey, key, peerAddress) } +// StartCapture begins capturing packets on this client's tunnel device. +// Only one capture can be active at a time; starting a new one stops the previous. +// Call StopCapture (or CaptureSession.Stop) to end it. +func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + var matcher capture.Matcher + if opts.Filter != "" { + m, err := capture.ParseFilter(opts.Filter) + if err != nil { + return nil, fmt.Errorf("parse filter: %w", err) + } + matcher = m + } + + sess, err := capture.NewSession(capture.Options{ + Output: opts.Output, + TextOutput: opts.TextOutput, + Matcher: matcher, + Verbose: opts.Verbose, + ASCII: opts.ASCII, + }) + if err != nil { + return nil, fmt.Errorf("create capture session: %w", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + return nil, fmt.Errorf("set capture: %w", err) + } + + return &CaptureSession{sess: sess, engine: engine}, nil +} + +// StopCapture stops the active capture session if one is running. +func (c *Client) StopCapture() error { + engine, err := c.getEngine() + if err != nil { + return err + } + return engine.SetCapture(nil) +} + // getEngine safely retrieves the engine from the client with proper locking. // Returns ErrClientNotStarted if the client is not started. // Returns ErrEngineNotStarted if the engine is not available. diff --git a/client/embed/expose.go b/client/embed/expose.go new file mode 100644 index 000000000..825bb90ee --- /dev/null +++ b/client/embed/expose.go @@ -0,0 +1,45 @@ +package embed + +import ( + "context" + "errors" + + "github.com/netbirdio/netbird/client/internal/expose" +) + +const ( + // ExposeProtocolHTTP exposes the service as HTTP. + ExposeProtocolHTTP = expose.ProtocolHTTP + // ExposeProtocolHTTPS exposes the service as HTTPS. + ExposeProtocolHTTPS = expose.ProtocolHTTPS + // ExposeProtocolTCP exposes the service as TCP. + ExposeProtocolTCP = expose.ProtocolTCP + // ExposeProtocolUDP exposes the service as UDP. + ExposeProtocolUDP = expose.ProtocolUDP + // ExposeProtocolTLS exposes the service as TLS. + ExposeProtocolTLS = expose.ProtocolTLS +) + +// ExposeRequest is a request to expose a local service via the NetBird reverse proxy. +type ExposeRequest = expose.Request + +// ExposeProtocolType represents the protocol used for exposing a service. +type ExposeProtocolType = expose.ProtocolType + +// ExposeSession represents an active expose session. Use Wait to block until the session ends. +type ExposeSession struct { + Domain string + ServiceName string + ServiceURL string + + mgr *expose.Manager +} + +// Wait blocks while keeping the expose session alive. +// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session. +func (s *ExposeSession) Wait(ctx context.Context) error { + if s == nil || s.mgr == nil { + return errors.New("expose session is not initialized") + } + return s.mgr.KeepAlive(ctx, s.Domain) +} diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 12dcaee8a..d916ebad4 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "strconv" "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" @@ -35,20 +36,34 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" type FWType int func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { - // on the linux system we try to user nftables or iptables - // in any case, because we need to allow netbird interface traffic - // so we use AllowNetbird traffic from these firewall managers - // for the userspace packet filtering firewall + // We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall. + if iface.IsUserspaceBind() && forceUserspaceFirewall() { + log.Info("forcing userspace firewall") + return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu) + } + + // Use native firewall for either kernel or userspace, the interface appears identical to netfilter fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) + // Kernel cannot fall back to anything else, need to return error if !iface.IsUserspaceBind() { return fm, err } + // Fall back to the userspace packet filter if native is unavailable if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) + return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu) } - return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) + + // Native firewall handles packet filtering, but the userspace WireGuard bind + // needs a device filter for DNS interception hooks. Install a minimal + // hooks-only filter that passes all traffic through to the kernel firewall. + if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil { + log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err) + } + + return fm, nil } func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { @@ -160,3 +175,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool { _, err := client.ListChains("filter") return err == nil } + +func forceUserspaceFirewall() bool { + val := os.Getenv(EnvForceUserspaceFirewall) + if val == "" { + return false + } + + force, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err) + return false + } + return force +} diff --git a/client/firewall/firewalld/firewalld.go b/client/firewall/firewalld/firewalld.go new file mode 100644 index 000000000..188ea61dd --- /dev/null +++ b/client/firewall/firewalld/firewalld.go @@ -0,0 +1,11 @@ +// Package firewalld integrates with the firewalld daemon so NetBird can place +// its wg interface into firewalld's "trusted" zone. This is required because +// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent +// versions, which returns EPERM to any other process that tries to insert +// rules into them. The workaround mirrors what Tailscale does: let firewalld +// itself add the accept rules to its own chains by trusting the interface. +package firewalld + +// TrustedZone is the firewalld zone name used for interfaces whose traffic +// should bypass firewalld filtering. +const TrustedZone = "trusted" diff --git a/client/firewall/firewalld/firewalld_linux.go b/client/firewall/firewalld/firewalld_linux.go new file mode 100644 index 000000000..924a04b0a --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux.go @@ -0,0 +1,260 @@ +//go:build linux + +package firewalld + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" + "sync" + "time" + + "github.com/godbus/dbus/v5" + log "github.com/sirupsen/logrus" +) + +const ( + dbusDest = "org.fedoraproject.FirewallD1" + dbusPath = "/org/fedoraproject/FirewallD1" + dbusRootIface = "org.fedoraproject.FirewallD1" + dbusZoneIface = "org.fedoraproject.FirewallD1.zone" + + errZoneAlreadySet = "ZONE_ALREADY_SET" + errAlreadyEnabled = "ALREADY_ENABLED" + errUnknownIface = "UNKNOWN_INTERFACE" + errNotEnabled = "NOT_ENABLED" + + // callTimeout bounds each individual DBus or firewall-cmd invocation. + // A fresh context is created for each call so a slow DBus probe can't + // exhaust the deadline before the firewall-cmd fallback gets to run. + callTimeout = 3 * time.Second +) + +var ( + errDBusUnavailable = errors.New("firewalld dbus unavailable") + + // trustLogOnce ensures the "added to trusted zone" message is logged at + // Info level only for the first successful add per process; repeat adds + // from other init paths are quieter. + trustLogOnce sync.Once + + parentCtxMu sync.RWMutex + parentCtx context.Context = context.Background() +) + +// SetParentContext installs a parent context whose cancellation aborts any +// in-flight TrustInterface call. It does not affect UntrustInterface, which +// always uses a fresh Background-rooted timeout so cleanup can still run +// during engine shutdown when the engine context is already cancelled. +func SetParentContext(ctx context.Context) { + parentCtxMu.Lock() + parentCtx = ctx + parentCtxMu.Unlock() +} + +func getParentContext() context.Context { + parentCtxMu.RLock() + defer parentCtxMu.RUnlock() + return parentCtx +} + +// TrustInterface places iface into firewalld's trusted zone if firewalld is +// running. It is idempotent and best-effort: errors are returned so callers +// can log, but a non-running firewalld is not an error. Only the first +// successful call per process logs at Info. Respects the parent context set +// via SetParentContext so startup-time cancellation unblocks it. +func TrustInterface(iface string) error { + parent := getParentContext() + if !isRunning(parent) { + return nil + } + if err := addTrusted(parent, iface); err != nil { + return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err) + } + trustLogOnce.Do(func() { + log.Infof("added %s to firewalld trusted zone", iface) + }) + log.Debugf("firewalld: ensured %s is in trusted zone", iface) + return nil +} + +// UntrustInterface removes iface from firewalld's trusted zone if firewalld +// is running. Idempotent. Uses a Background-rooted timeout so it still runs +// during shutdown after the engine context has been cancelled. +func UntrustInterface(iface string) error { + if !isRunning(context.Background()) { + return nil + } + if err := removeTrusted(context.Background(), iface); err != nil { + return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err) + } + return nil +} + +func newCallContext(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, callTimeout) +} + +func isRunning(parent context.Context) bool { + ctx, cancel := newCallContext(parent) + ok, err := isRunningDBus(ctx) + cancel() + if err == nil { + return ok + } + if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) { + ctx, cancel = newCallContext(parent) + defer cancel() + return isRunningCLI(ctx) + } + return false +} + +func addTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := addDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return addCLI(ctx, iface) +} + +func removeTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := removeDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return removeCLI(ctx, iface) +} + +func isRunningDBus(ctx context.Context) (bool, error) { + conn, err := dbus.SystemBus() + if err != nil { + return false, fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + var zone string + if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil { + return false, fmt.Errorf("firewalld getDefaultZone: %w", err) + } + return true, nil +} + +func isRunningCLI(ctx context.Context) bool { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return false + } + return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil +} + +func addDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errAlreadyEnabled) { + return nil + } + + if dbusErrContains(call.Err, errZoneAlreadySet) { + move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface) + if move.Err != nil { + return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err) + } + return nil + } + + return fmt.Errorf("firewalld addInterface: %w", call.Err) +} + +func removeDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) { + return nil + } + + return fmt.Errorf("firewalld removeInterface: %w", call.Err) +} + +func addCLI(ctx context.Context, iface string) error { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + // --change-interface (no --permanent) binds the interface for the + // current runtime only; we do not want membership to persist across + // reboots because netbird re-asserts it on every startup. + out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface, + ).CombinedOutput() + if err != nil { + return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out))) + } + return nil +} + +func removeCLI(ctx context.Context, iface string) error { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface, + ).CombinedOutput() + if err != nil { + msg := strings.TrimSpace(string(out)) + if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) { + return nil + } + return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg) + } + return nil +} + +func dbusErrContains(err error, code string) bool { + if err == nil { + return false + } + var de dbus.Error + if errors.As(err, &de) { + for _, b := range de.Body { + if s, ok := b.(string); ok && strings.Contains(s, code) { + return true + } + } + } + return strings.Contains(err.Error(), code) +} diff --git a/client/firewall/firewalld/firewalld_linux_test.go b/client/firewall/firewalld/firewalld_linux_test.go new file mode 100644 index 000000000..d812745fc --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux_test.go @@ -0,0 +1,49 @@ +//go:build linux + +package firewalld + +import ( + "errors" + "testing" + + "github.com/godbus/dbus/v5" +) + +func TestDBusErrContains(t *testing.T) { + tests := []struct { + name string + err error + code string + want bool + }{ + {"nil error", nil, errZoneAlreadySet, false}, + {"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true}, + {"plain error miss", errors.New("something else"), errZoneAlreadySet, false}, + { + "dbus.Error body match", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}}, + errZoneAlreadySet, + true, + }, + { + "dbus.Error body miss", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}}, + errAlreadyEnabled, + false, + }, + { + "dbus.Error non-string body falls back to Error()", + dbus.Error{Name: "x", Body: []any{123}}, + "x", + true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := dbusErrContains(tc.err, tc.code) + if got != tc.want { + t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want) + } + }) + } +} diff --git a/client/firewall/firewalld/firewalld_other.go b/client/firewall/firewalld/firewalld_other.go new file mode 100644 index 000000000..cfa28221d --- /dev/null +++ b/client/firewall/firewalld/firewalld_other.go @@ -0,0 +1,25 @@ +//go:build !linux + +package firewalld + +import "context" + +// SetParentContext is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func SetParentContext(context.Context) { + // intentionally empty: firewalld is a Linux-only daemon +} + +// TrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func TrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} + +// UntrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func UntrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} diff --git a/client/firewall/iface.go b/client/firewall/iface.go index b83c5f912..491f03269 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -7,6 +7,12 @@ import ( "github.com/netbirdio/netbird/client/iface/wgaddr" ) +// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when +// native iptables/nftables is available. This only applies when the WireGuard interface +// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of +// kernel netfilter rules. +const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL" + // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { Name() string diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d83798f09..e629f7881 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -21,6 +21,10 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" + + // mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent + // external DNAT from bypassing ACL rules. + mangleFwdKey = "MANGLE-FORWARD" ) type aclEntries map[string][][]string @@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error { } } + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err) + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := m.flushIPSet(ipsetName); err != nil { if errors.Is(err, ipset.ErrSetNotExist) { @@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error { } for chainName, rules := range m.entries { + // mangle FORWARD guard rules are handled separately below + if chainName == mangleFwdKey { + continue + } for _, rule := range rules { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { log.Debugf("failed to create input chain jump rule: %s", err) @@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error { } clear(m.optionalEntries) + // Insert mangle FORWARD guard rules to prevent external DNAT bypass. + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to add mangle FORWARD guard rule: %v", err) + } + } + return nil } @@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN}) + + // Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it + // traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD + // can be inserted above ours. Mangle runs before filter, so these guard rules enforce the + // ACL mark check where it cannot be overridden. + m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", + "-j", "ACCEPT", + }) + m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "DNAT", + "-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), + "-j", "DROP", + }) } func (m *aclManager) seedInitialOptionalEntries() { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 04c338375..7d8cd7f8c 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -33,7 +34,6 @@ type Manager struct { type iFaceMapper interface { Name() string Address() wgaddr.Address - IsUserspaceBind() bool } // Create iptables firewall manager @@ -64,10 +64,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { func (m *Manager) Init(stateManager *statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ - NameStr: m.wgIface.Name(), - WGAddress: m.wgIface.Address(), - UserspaceBind: m.wgIface.IsUserspaceBind(), - MTU: m.router.mtu, + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + MTU: m.router.mtu, }, } stateManager.RegisterState(state) @@ -88,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Warnf("raw table not available, notrack rules will be disabled: %v", err) } + // Trust after all fatal init steps so a later failure doesn't leave the + // interface in firewalld's trusted zone without a corresponding Close. + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + // persist early to ensure cleanup of chains go func() { if err := stateManager.PersistState(context.Background()); err != nil { @@ -193,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) } + // Appending to merr intentionally blocks DeleteState below so ShutdownState + // stays persisted and the crash-recovery path retries firewalld cleanup. + if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + // attempt to delete state only if all other operations succeeded if merr == nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil { @@ -203,12 +214,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nberrors.FormatErrorOrNil(merr) } -// AllowNetbird allows netbird interface traffic +// AllowNetbird allows netbird interface traffic. +// This is called when USPFilter wraps the native firewall, adding blanket accept +// rules so that packet filtering is handled in userspace instead of by netfilter. func (m *Manager) AllowNetbird() error { - if !m.wgIface.IsUserspaceBind() { - return nil - } - _, err := m.AddPeerFiltering( nil, net.IP{0, 0, 0, 0}, @@ -221,6 +230,11 @@ func (m *Manager) AllowNetbird() error { if err != nil { return fmt.Errorf("allow netbird interface traffic: %w", err) } + + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + return nil } @@ -286,6 +300,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + const ( chainNameRaw = "NETBIRD-RAW" chainOUTPUT = "OUTPUT" diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ee47a27c0..cc4bda0e0 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address { panic("AddressFunc is not set") } -func (i *iFaceMock) IsUserspaceBind() bool { return false } - func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 1fe4c149f..a7c4f67dd 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -36,6 +36,7 @@ const ( chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainNATOutput = "NETBIRD-NAT-OUTPUT" chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" @@ -43,6 +44,7 @@ const ( jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpNatOutput = "jump-nat-output" jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" @@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error { } log.Debug("flushing routing related tables") + + // Remove jump rules from built-in chains before deleting custom chains, + // otherwise the chain deletion fails with "device or resource busy". + jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { + log.Debugf("clean OUTPUT jump rule: %v", err) + } + for _, chainInfo := range []struct { chain string table string @@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainNATOutput, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) @@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto return nil } +// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use. +func (r *router) ensureNATOutputChain() error { + if _, exists := r.rules[jumpNatOutput]; exists { + return nil + } + + chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput) + if err != nil { + return fmt.Errorf("check chain %s: %w", chainNATOutput, err) + } + if !chainExists { + if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil { + return fmt.Errorf("create chain %s: %w", chainNATOutput, err) + } + } + + jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil { + if !chainExists { + if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil { + log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr) + } + } + return fmt.Errorf("add OUTPUT jump rule: %w", err) + } + r.rules[jumpNatOutput] = jumpRule + + r.updateState() + return nil +} + +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + if err := r.ensureNATOutputChain(); err != nil { + return err + } + + dnatRule := []string{ + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil { + return fmt.Errorf("add output DNAT rule: %w", err) + } + r.rules[ruleID] = dnatRule + + r.updateState() + return nil +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil { + return fmt.Errorf("delete output DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index c88774c1f..121c755e9 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -9,10 +9,9 @@ import ( ) type InterfaceState struct { - NameStr string `json:"name"` - WGAddress wgaddr.Address `json:"wg_address"` - UserspaceBind bool `json:"userspace_bind"` - MTU uint16 `json:"mtu"` + NameStr string `json:"name"` + WGAddress wgaddr.Address `json:"wg_address"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address { return i.WGAddress } -func (i *InterfaceState) IsUserspaceBind() bool { - return i.UserspaceBind -} - type ShutdownState struct { sync.Mutex diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3511a5463..d65d717b3 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -169,6 +169,14 @@ type Manager interface { // RemoveInboundDNAT removes inbound DNAT rule RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. + // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. + AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. + // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. + RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. // This prevents conntrack from interfering with WireGuard proxy communication. SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index f57b28abc..8cd5cc6b3 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -14,6 +14,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -40,7 +41,6 @@ func getTableName() string { type iFaceMapper interface { Name() string Address() wgaddr.Address - IsUserspaceBind() bool } // Manager of iptables firewall @@ -106,10 +106,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ - NameStr: m.wgIface.Name(), - WGAddress: m.wgIface.Address(), - UserspaceBind: m.wgIface.IsUserspaceBind(), - MTU: m.router.mtu, + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + MTU: m.router.mtu, }, }); err != nil { log.Errorf("failed to update state: %v", err) @@ -205,12 +204,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { return m.router.RemoveNatRule(pair) } -// AllowNetbird allows netbird interface traffic +// AllowNetbird allows netbird interface traffic. +// This is called when USPFilter wraps the native firewall, adding blanket accept +// rules so that packet filtering is handled in userspace instead of by netfilter. func (m *Manager) AllowNetbird() error { - if !m.wgIface.IsUserspaceBind() { - return nil - } - m.mutex.Lock() defer m.mutex.Unlock() @@ -221,6 +218,10 @@ func (m *Manager) AllowNetbird() error { return fmt.Errorf("flush allow input netbird rules: %w", err) } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + return nil } @@ -346,6 +347,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + const ( chainNameRawOutput = "netbird-raw-out" chainNameRawPrerouting = "netbird-raw-pre" diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 75b1e2b6c..d48e4ba88 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address { panic("AddressFunc is not set") } -func (i *iFaceMock) IsUserspaceBind() bool { return false } - func TestNftablesManager(t *testing.T) { // just check on the local interface diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index fde654c20..8cc0d2792 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -19,6 +19,7 @@ import ( "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" @@ -36,9 +37,12 @@ const ( chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" + chainNameNATOutput = "netbird-nat-output" chainNameForward = "FORWARD" chainNameMangleForward = "netbird-mangle-forward" + firewalldTableName = "firewalld" + userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptInputRule = "inputaccept" @@ -132,6 +136,10 @@ func (r *router) Reset() error { merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } + if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + if err := r.removeNatPreroutingRules(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) } @@ -279,6 +287,10 @@ func (r *router) createContainers() error { log.Errorf("failed to add accept rules for the forward chain: %s", err) } + if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to refresh rules: %s", err) } @@ -1318,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } + // Skip firewalld-owned chains. Firewalld creates its chains with the + // NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM. + // We delegate acceptance to firewalld by trusting the interface instead. + if chain.Table.Name == firewalldTableName { + return false + } + // Skip all iptables-managed tables in the ip family if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { return false @@ -1853,6 +1872,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto return nil } +// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use. +func (r *router) ensureNATOutputChain() error { + if _, exists := r.chains[chainNameNATOutput]; exists { + return nil + } + + r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameNATOutput, + Table: r.workTable, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityNATDest, + Type: nftables.ChainTypeNAT, + }) + + if err := r.conn.Flush(); err != nil { + delete(r.chains, chainNameNATOutput) + return fmt.Errorf("create NAT output chain: %w", err) + } + return nil +} + +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + if err := r.ensureNATOutputChain(); err != nil { + return err + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameNATOutput], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add output DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + rule, exists := r.rules[ruleID] + if !exists { + return nil + } + + if rule.Handle == 0 { + log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID) + delete(r.rules, ruleID) + return nil + } + + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete output DNAT rule: %w", err) + } + delete(r.rules, ruleID) + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index 48b7b3741..462ad2556 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -8,10 +8,9 @@ import ( ) type InterfaceState struct { - NameStr string `json:"name"` - WGAddress wgaddr.Address `json:"wg_address"` - UserspaceBind bool `json:"userspace_bind"` - MTU uint16 `json:"mtu"` + NameStr string `json:"name"` + WGAddress wgaddr.Address `json:"wg_address"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address { return i.WGAddress } -func (i *InterfaceState) IsUserspaceBind() bool { - return i.UserspaceBind -} - type ShutdownState struct { InterfaceState *InterfaceState `json:"interface_state,omitempty"` } diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 6a6533344..b120cdf12 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -3,6 +3,9 @@ package uspfilter import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/firewalld" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { if m.nativeFirewall != nil { return m.nativeFirewall.Close(stateManager) } + if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to untrust interface in firewalld: %v", err) + } return nil } @@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error { if m.nativeFirewall != nil { return m.nativeFirewall.AllowNetbird() } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } return nil } diff --git a/client/firewall/uspfilter/common/hooks.go b/client/firewall/uspfilter/common/hooks.go new file mode 100644 index 000000000..dadd800dd --- /dev/null +++ b/client/firewall/uspfilter/common/hooks.go @@ -0,0 +1,37 @@ +package common + +import ( + "net/netip" + "sync/atomic" +) + +// PacketHook stores a registered hook for a specific IP:port. +type PacketHook struct { + IP netip.Addr + Port uint16 + Fn func([]byte) bool +} + +// HookMatches checks if a packet's destination matches the hook and invokes it. +func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool { + if h == nil { + return false + } + if h.IP == dstIP && h.Port == dport { + return h.Fn(packetData) + } + return false +} + +// SetHook atomically stores a hook, handling nil removal. +func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) { + if hook == nil { + ptr.Store(nil) + return + } + ptr.Store(&PacketHook{ + IP: ip, + Port: dPort, + Fn: hook, + }) +} diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go index 7296953db..9c06eb3f7 100644 --- a/client/firewall/uspfilter/common/iface.go +++ b/client/firewall/uspfilter/common/iface.go @@ -9,6 +9,7 @@ import ( // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { + Name() string SetFilter(device.PacketFilter) error Address() wgaddr.Address GetWGDevice() *wgdevice.Device diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index df2e274eb..3787e63a8 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -115,12 +115,13 @@ type Manager struct { localipmanager *localIPManager - udpTracker *conntrack.UDPTracker - icmpTracker *conntrack.ICMPTracker - tcpTracker *conntrack.TCPTracker - forwarder atomic.Pointer[forwarder.Forwarder] - logger *nblog.Logger - flowLogger nftypes.FlowLogger + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker + forwarder atomic.Pointer[forwarder.Forwarder] + pendingCapture atomic.Pointer[forwarder.PacketCapture] + logger *nblog.Logger + flowLogger nftypes.FlowLogger blockRule firewall.Rule @@ -140,6 +141,10 @@ type Manager struct { mtu uint16 mssClampValue uint16 mssClampEnabled bool + + // Only one hook per protocol is supported. Outbound direction only. + udpHookOut atomic.Pointer[common.PacketHook] + tcpHookOut atomic.Pointer[common.PacketHook] } // decoder for packages @@ -347,6 +352,19 @@ func (m *Manager) determineRouting() error { return nil } +// SetPacketCapture sets or clears packet capture on the forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. +func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) { + if pc == nil { + m.pendingCapture.Store(nil) + } else { + m.pendingCapture.Store(&pc) + } + if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(pc) + } +} + // initForwarder initializes the forwarder, it disables routing on errors func (m *Manager) initForwarder() error { if m.forwarder.Load() != nil { @@ -368,6 +386,11 @@ func (m *Manager) initForwarder() error { m.forwarder.Store(forwarder) + // Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture. + if pc := m.pendingCapture.Load(); pc != nil { + forwarder.SetCapture(*pc) + } + log.Debug("forwarder initialized") return nil @@ -594,6 +617,8 @@ func (m *Manager) resetState() { maps.Clear(m.incomingRules) maps.Clear(m.routeRulesMap) m.routeRules = m.routeRules[:0] + m.udpHookOut.Store(nil) + m.tcpHookOut.Store(nil) if m.udpTracker != nil { m.udpTracker.Close() @@ -608,6 +633,7 @@ func (m *Manager) resetState() { } if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(nil) fwder.Stop() } @@ -713,6 +739,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } case layers.LayerTypeTCP: + if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) { + return true + } // Clamp MSS on all TCP SYN packets, including those from local IPs. // SNATed routed traffic may appear as local IP but still requires clamping. if m.mssClampEnabled { @@ -895,39 +924,12 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt d.dnatOrigPort = 0 } -// udpHooksDrop checks if any UDP hooks should drop the packet func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { - m.mutex.RLock() - defer m.mutex.RUnlock() + return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData) +} - // Check specific destination IP first - if rules, exists := m.outgoingRules[dstIP]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } - } - - // Check IPv4 unspecified address - if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } - } - - // Check IPv6 unspecified address - if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } - } - - return false +func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { + return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData) } // filterInbound implements filtering logic for incoming packets. @@ -1278,12 +1280,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d return rule.mgmtId, rule.drop, true } case layers.LayerTypeUDP: - // if rule has UDP hook (and if we are here we match this rule) - // we ignore rule.drop and call this hook - if rule.udpHook != nil { - return rule.mgmtId, rule.udpHook(packetData), true - } - if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { return rule.mgmtId, rule.drop, true } @@ -1342,65 +1338,14 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot return sourceMatched } -// AddUDPPacketHook calls hook when UDP packet from given direction matched -// -// Hook function returns flag which indicates should be the matched package dropped or not -func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string { - r := PeerRule{ - id: uuid.New().String(), - ip: ip, - protoLayer: layers.LayerTypeUDP, - dPort: &firewall.Port{Values: []uint16{dPort}}, - ipLayer: layers.LayerTypeIPv6, - udpHook: hook, - } - - if ip.Is4() { - r.ipLayer = layers.LayerTypeIPv4 - } - - m.mutex.Lock() - if in { - // Incoming UDP hooks are stored in allow rules map - if _, ok := m.incomingRules[r.ip]; !ok { - m.incomingRules[r.ip] = make(map[string]PeerRule) - } - m.incomingRules[r.ip][r.id] = r - } else { - if _, ok := m.outgoingRules[r.ip]; !ok { - m.outgoingRules[r.ip] = make(map[string]PeerRule) - } - m.outgoingRules[r.ip][r.id] = r - } - m.mutex.Unlock() - - return r.id +// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove. +func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { + common.SetHook(&m.udpHookOut, ip, dPort, hook) } -// RemovePacketHook removes packet hook by given ID -func (m *Manager) RemovePacketHook(hookID string) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - // Check incoming hooks (stored in allow rules) - for _, arr := range m.incomingRules { - for _, r := range arr { - if r.id == hookID { - delete(arr, r.id) - return nil - } - } - } - // Check outgoing hooks - for _, arr := range m.outgoingRules { - for _, r := range arr { - if r.id == hookID { - delete(arr, r.id) - return nil - } - } - } - return fmt.Errorf("hook with given id not found") +// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove. +func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { + common.SetHook(&m.tcpHookOut, ip, dPort, hook) } // SetLogLevel sets the log level for the firewall manager diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 55a8e723c..5fb9fef0e 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" wgdevice "golang.zx2c4.com/wireguard/device" @@ -30,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger()) var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() type IFaceMock struct { + NameFunc func() string SetFilterFunc func(device.PacketFilter) error AddressFunc func() wgaddr.Address GetWGDeviceFunc func() *wgdevice.Device GetDeviceFunc func() *device.FilteredDevice } +func (i *IFaceMock) Name() string { + if i.NameFunc == nil { + return "wgtest" + } + return i.NameFunc() +} + func (i *IFaceMock) GetWGDevice() *wgdevice.Device { if i.GetWGDeviceFunc == nil { return nil @@ -186,81 +195,52 @@ func TestManagerDeleteRule(t *testing.T) { } } -func TestAddUDPPacketHook(t *testing.T) { - tests := []struct { - name string - in bool - expDir fw.RuleDirection - ip netip.Addr - dPort uint16 - hook func([]byte) bool - expectedID string - }{ - { - name: "Test Outgoing UDP Packet Hook", - in: false, - expDir: fw.RuleDirectionOUT, - ip: netip.MustParseAddr("10.168.0.1"), - dPort: 8000, - hook: func([]byte) bool { return true }, - }, - { - name: "Test Incoming UDP Packet Hook", - in: true, - expDir: fw.RuleDirectionIN, - ip: netip.MustParseAddr("::1"), - dPort: 9000, - hook: func([]byte) bool { return false }, - }, - } +func TestSetUDPPacketHook(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger, nbiface.DefaultMTU) - require.NoError(t, err) + var called bool + manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool { + called = true + return true + }) - manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) + h := manager.udpHookOut.Load() + require.NotNil(t, h) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP) + assert.Equal(t, uint16(8000), h.Port) + assert.True(t, h.Fn(nil)) + assert.True(t, called) - var addedRule PeerRule - if tt.in { - // Incoming UDP hooks are stored in allow rules map - if len(manager.incomingRules[tt.ip]) != 1 { - t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip])) - return - } - for _, rule := range manager.incomingRules[tt.ip] { - addedRule = rule - } - } else { - if len(manager.outgoingRules[tt.ip]) != 1 { - t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip])) - return - } - for _, rule := range manager.outgoingRules[tt.ip] { - addedRule = rule - } - } + manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil) + assert.Nil(t, manager.udpHookOut.Load()) +} - if tt.ip.Compare(addedRule.ip) != 0 { - t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) - return - } - if tt.dPort != addedRule.dPort.Values[0] { - t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0]) - return - } - if layers.LayerTypeUDP != addedRule.protoLayer { - t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) - return - } - if addedRule.udpHook == nil { - t.Errorf("expected udpHook to be set") - return - } - }) - } +func TestSetTCPPacketHook(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) + + var called bool + manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool { + called = true + return true + }) + + h := manager.tcpHookOut.Load() + require.NotNil(t, h) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP) + assert.Equal(t, uint16(53), h.Port) + assert.True(t, h.Fn(nil)) + assert.True(t, called) + + manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil) + assert.Nil(t, manager.tcpHookOut.Load()) } // TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added @@ -530,39 +510,12 @@ func TestRemovePacketHook(t *testing.T) { require.NoError(t, manager.Close(nil)) }() - // Add a UDP packet hook - hookFunc := func(data []byte) bool { return true } - hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc) + manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true }) - // Assert the hook is added by finding it in the manager's outgoing rules - found := false - for _, arr := range manager.outgoingRules { - for _, rule := range arr { - if rule.id == hookID { - found = true - break - } - } - } + require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered") - if !found { - t.Fatalf("The hook was not added properly.") - } - - // Now remove the packet hook - err = manager.RemovePacketHook(hookID) - if err != nil { - t.Fatalf("Failed to remove hook: %s", err) - } - - // Assert the hook is removed by checking it in the manager's outgoing rules - for _, arr := range manager.outgoingRules { - for _, rule := range arr { - if rule.id == hookID { - t.Fatalf("The hook was not removed properly.") - } - } - } + manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil) + assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed") } func TestProcessOutgoingHooks(t *testing.T) { @@ -592,8 +545,7 @@ func TestProcessOutgoingHooks(t *testing.T) { } hookCalled := false - hookID := manager.AddUDPPacketHook( - false, + manager.SetUDPPacketHook( netip.MustParseAddr("100.10.0.100"), 53, func([]byte) bool { @@ -601,7 +553,6 @@ func TestProcessOutgoingHooks(t *testing.T) { return true }, ) - require.NotEmpty(t, hookID) // Create test UDP packet ipv4 := &layers.IPv4{ diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 692a24140..96ab89af8 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -12,12 +12,19 @@ import ( nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + Offer(data []byte, outbound bool) +} + // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device type endpoint struct { logger *nblog.Logger dispatcher stack.NetworkDispatcher device *wgdevice.Device mtu atomic.Uint32 + capture atomic.Pointer[PacketCapture] } func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { @@ -54,13 +61,17 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) continue } - // Send the packet through WireGuard + pktBytes := data.AsSlice() + address := netHeader.DestinationAddress() - err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) - if err != nil { + if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil { e.logger.Error1("CreateOutboundPacket: %v", err) continue } + + if pc := e.capture.Load(); pc != nil { + (*pc).Offer(pktBytes, true) + } written++ } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index d17c3cd5c..925273f24 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -139,6 +139,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return f, nil } +// SetCapture sets or clears the packet capture on the forwarder endpoint. +// This captures outbound packets that bypass the FilteredDevice (netstack forwarding). +func (f *Forwarder) SetCapture(pc PacketCapture) { + if pc == nil { + f.endpoint.capture.Store(nil) + return + } + f.endpoint.capture.Store(&pc) +} + func (f *Forwarder) InjectIncomingPacket(payload []byte) error { if len(payload) < header.IPv4MinimumSize { return fmt.Errorf("packet too small: %d bytes", len(payload)) diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index cb3db325d..217423901 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -270,5 +270,9 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload [] return 0 } + if pc := f.endpoint.capture.Load(); pc != nil { + (*pc).Offer(fullPacket, true) + } + return len(fullPacket) } diff --git a/client/firewall/uspfilter/hooks_filter.go b/client/firewall/uspfilter/hooks_filter.go new file mode 100644 index 000000000..8d3cc0f5c --- /dev/null +++ b/client/firewall/uspfilter/hooks_filter.go @@ -0,0 +1,90 @@ +package uspfilter + +import ( + "encoding/binary" + "net/netip" + "sync/atomic" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + "github.com/netbirdio/netbird/client/iface/device" +) + +const ( + ipv4HeaderMinLen = 20 + ipv4ProtoOffset = 9 + ipv4FlagsOffset = 6 + ipv4DstOffset = 16 + ipProtoUDP = 17 + ipProtoTCP = 6 + ipv4FragOffMask = 0x1fff + // dstPortOffset is the offset of the destination port within a UDP or TCP header. + dstPortOffset = 2 +) + +// HooksFilter is a minimal packet filter that only handles outbound DNS hooks. +// It is installed on the WireGuard interface when the userspace bind is active +// but a full firewall filter (Manager) is not needed because a native kernel +// firewall (nftables/iptables) handles packet filtering. +type HooksFilter struct { + udpHook atomic.Pointer[common.PacketHook] + tcpHook atomic.Pointer[common.PacketHook] +} + +var _ device.PacketFilter = (*HooksFilter)(nil) + +// FilterOutbound checks outbound packets for DNS hook matches. +// Only IPv4 packets matching the registered hook IP:port are intercepted. +// IPv6 and non-IP packets pass through unconditionally. +func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool { + if len(packetData) < ipv4HeaderMinLen { + return false + } + + // Only process IPv4 packets, let everything else pass through. + if packetData[0]>>4 != 4 { + return false + } + + ihl := int(packetData[0]&0x0f) * 4 + if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 { + return false + } + + // Skip non-first fragments: they don't carry L4 headers. + flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2]) + if flagsAndOffset&ipv4FragOffMask != 0 { + return false + } + + dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4]) + if !ok { + return false + } + + proto := packetData[ipv4ProtoOffset] + dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2]) + + switch proto { + case ipProtoUDP: + return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData) + case ipProtoTCP: + return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData) + default: + return false + } +} + +// FilterInbound allows all inbound packets (native firewall handles filtering). +func (f *HooksFilter) FilterInbound([]byte, int) bool { + return false +} + +// SetUDPPacketHook registers the UDP packet hook. +func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) { + common.SetHook(&f.udpHook, ip, dPort, hook) +} + +// SetTCPPacketHook registers the TCP packet hook. +func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) { + common.SetHook(&f.tcpHook, ip, dPort, hook) +} diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index ffc807f46..f63fe3e45 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { if err != nil { log.Warnf("failed to get interfaces: %v", err) } else { + // TODO: filter out down interfaces (net.FlagUp). Also handle the reverse + // case where an interface comes up between refreshes. for _, intf := range interfaces { m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) } diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 597f892cf..8ed32eb5e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +// TODO: also delegate to nativeFirewall when available for kernel WG mode func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType switch protocol { @@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) } +// AddOutputDNAT delegates to the native firewall if available. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if m.nativeFirewall == nil { + return fmt.Errorf("output DNAT not supported without native firewall") + } + return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT delegates to the native firewall if available. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if m.nativeFirewall == nil { + return nil + } + return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + // translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { if !m.portDNATEnabled.Load() { diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index dbe3a7858..08d68a78e 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -18,9 +18,7 @@ type PeerRule struct { protoLayer gopacket.LayerType sPort *firewall.Port dPort *firewall.Port - drop bool - - udpHook func([]byte) bool + drop bool } // ID returns the rule id diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index d9f9f1aa8..657f96fc0 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) { { name: "UDPTraffic_WithHook", setup: func(m *Manager) { - hookFunc := func([]byte) bool { - return true - } - m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc) + m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { + return true // drop (intercepted by hook) + }) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN) + return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT) }, expectedStages: []PacketStage{ StageReceived, - StageInboundPortDNAT, - StageInbound1to1NAT, - StageConntrack, - StageRouting, - StagePeerACL, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: false, diff --git a/client/iface/bind/ice_bind_test.go b/client/iface/bind/ice_bind_test.go index 1fdd955c9..f49e68508 100644 --- a/client/iface/bind/ice_bind_test.go +++ b/client/iface/bind/ice_bind_test.go @@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) { ipv6Count++ } - assert.Equal(t, packetsPerFamily, ipv4Count) - assert.Equal(t, packetsPerFamily, ipv6Count) + // Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The + // routing-correctness checks above are the real assertions; the counts + // are a sanity bound to catch a totally silent path. + minDelivered := packetsPerFamily * 80 / 100 + assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold") + assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold") } func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 708f38d26..fc1c65efa 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -3,6 +3,7 @@ package device import ( "net/netip" "sync" + "sync/atomic" "golang.zx2c4.com/wireguard/tun" ) @@ -15,14 +16,25 @@ type PacketFilter interface { // FilterInbound filter incoming packets from external sources to host FilterInbound(packetData []byte, size int) bool - // AddUDPPacketHook calls hook when UDP packet from given direction matched - // - // Hook function returns flag which indicates should be the matched package dropped or not. - // Hook function receives raw network packet data as argument. - AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string + // SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port. + // Hook function returns true if the packet should be dropped. + // Only one UDP hook is supported; calling again replaces the previous hook. + // Pass nil hook to remove. + SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) - // RemovePacketHook removes hook by ID - RemovePacketHook(hookID string) error + // SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port. + // Hook function returns true if the packet should be dropped. + // Only one TCP hook is supported; calling again replaces the previous hook. + // Pass nil hook to remove. + SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) +} + +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + // Offer submits a packet for capture. outbound is true for packets + // leaving the host (Read path), false for packets arriving (Write path). + Offer(data []byte, outbound bool) } // FilteredDevice to override Read or Write of packets @@ -30,6 +42,7 @@ type FilteredDevice struct { tun.Device filter PacketFilter + capture atomic.Pointer[PacketCapture] mutex sync.RWMutex closeOnce sync.Once } @@ -60,20 +73,25 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() - if filter == nil { - return + if filter != nil { + for i := 0; i < n; i++ { + if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { + bufs = append(bufs[:i], bufs[i+1:]...) + sizes = append(sizes[:i], sizes[i+1:]...) + n-- + i-- + } + } } - for i := 0; i < n; i++ { - if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { - bufs = append(bufs[:i], bufs[i+1:]...) - sizes = append(sizes[:i], sizes[i+1:]...) - n-- - i-- + if pc := d.capture.Load(); pc != nil { + for i := 0; i < n; i++ { + (*pc).Offer(bufs[i][offset:offset+sizes[i]], true) } } @@ -82,6 +100,13 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er // Write wraps write method with filtering feature func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { + // Capture before filtering so dropped packets are still visible in captures. + if pc := d.capture.Load(); pc != nil { + for _, buf := range bufs { + (*pc).Offer(buf[offset:], false) + } + } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -93,9 +118,10 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { filteredBufs := make([][]byte, 0, len(bufs)) dropped := 0 for _, buf := range bufs { - if !filter.FilterInbound(buf[offset:], len(buf)) { - filteredBufs = append(filteredBufs, buf) + if filter.FilterInbound(buf[offset:], len(buf)) { dropped++ + } else { + filteredBufs = append(filteredBufs, buf) } } @@ -110,3 +136,14 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.filter = filter d.mutex.Unlock() } + +// SetCapture sets or clears the packet capture sink. Pass nil to disable. +// Uses atomic store so the hot path (Read/Write) is a single pointer load +// with no locking overhead when capture is off. +func (d *FilteredDevice) SetCapture(pc PacketCapture) { + if pc == nil { + d.capture.Store(nil) + return + } + d.capture.Store(&pc) +} diff --git a/client/iface/device/device_filter_test.go b/client/iface/device/device_filter_test.go index eef783542..8fb16ca8d 100644 --- a/client/iface/device/device_filter_test.go +++ b/client/iface/device/device_filter_test.go @@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) { t.Errorf("unexpected error: %v", err) return } - if n != 0 { + if n != 1 { t.Errorf("expected n=1, got %d", n) return } diff --git a/client/iface/iface.go b/client/iface/iface.go index 9b331d68c..655dd1682 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error // Close closes the tunnel interface func (w *WGIface) Close() error { w.mu.Lock() - defer w.mu.Unlock() var result *multierror.Error @@ -225,7 +224,15 @@ func (w *WGIface) Close() error { result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) } - if err := w.tun.Close(); err != nil { + // Release w.mu before calling w.tun.Close(): the underlying + // wireguard-go device.Close() waits for its send/receive goroutines + // to drain. Some of those goroutines re-enter WGIface methods that + // take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so + // holding the mutex here would deadlock the shutdown path. + tun := w.tun + w.mu.Unlock() + + if err := tun.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) } diff --git a/client/iface/iface_close_test.go b/client/iface/iface_close_test.go new file mode 100644 index 000000000..171e15d0a --- /dev/null +++ b/client/iface/iface_close_test.go @@ -0,0 +1,113 @@ +//go:build !android + +package iface + +import ( + "errors" + "sync" + "testing" + "time" + + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// fakeTunDevice implements WGTunDevice and lets the test control when +// Close() returns. It mimics the wireguard-go shutdown path, which blocks +// until its goroutines drain. Some of those goroutines (e.g. the packet +// filter DNS hook in client/internal/dns) call back into WGIface, so if +// WGIface.Close() held w.mu across tun.Close() the shutdown would +// deadlock. +type fakeTunDevice struct { + closeStarted chan struct{} + unblockClose chan struct{} +} + +func (f *fakeTunDevice) Create() (device.WGConfigurer, error) { + return nil, errors.New("not implemented") +} +func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { + return nil, errors.New("not implemented") +} +func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil } +func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} } +func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU } +func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" } +func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil } +func (f *fakeTunDevice) Device() *wgdevice.Device { return nil } +func (f *fakeTunDevice) GetNet() *netstack.Net { return nil } +func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil } + +func (f *fakeTunDevice) Close() error { + close(f.closeStarted) + <-f.unblockClose + return nil +} + +type fakeProxyFactory struct{} + +func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil } +func (fakeProxyFactory) GetProxyPort() uint16 { return 0 } +func (fakeProxyFactory) Free() error { return nil } + +// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock +// that surfaces as a macOS test-timeout in +// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu +// while waiting for the wireguard-go device goroutines to finish, and +// one of those goroutines (the DNS filter hook) calls back into +// WGIface.GetDevice() which needs the same mutex. The fix is to drop +// the lock before tun.Close() returns control. +func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) { + tun := &fakeTunDevice{ + closeStarted: make(chan struct{}), + unblockClose: make(chan struct{}), + } + w := &WGIface{ + tun: tun, + wgProxyFactory: fakeProxyFactory{}, + } + + closeDone := make(chan error, 1) + go func() { + closeDone <- w.Close() + }() + + select { + case <-tun.closeStarted: + case <-time.After(2 * time.Second): + close(tun.unblockClose) + t.Fatal("tun.Close() was never invoked") + } + + // Simulate the WireGuard read goroutine calling back into WGIface + // via the packet filter's DNS hook. If Close() still held w.mu + // during tun.Close(), this would block until the test timeout. + getDeviceDone := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = w.GetDevice() + close(getDeviceDone) + }() + + select { + case <-getDeviceDone: + case <-time.After(2 * time.Second): + close(tun.unblockClose) + wg.Wait() + t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun") + } + + close(tun.unblockClose) + select { + case <-closeDone: + case <-time.After(2 * time.Second): + t.Fatal("WGIface.Close() never returned after the tun was unblocked") + } +} diff --git a/client/iface/mocks/filter.go b/client/iface/mocks/filter.go index 566068aa5..5ae98039c 100644 --- a/client/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { return m.recorder } -// AddUDPPacketHook mocks base method. -func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string { +// SetUDPPacketHook mocks base method. +func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(string) - return ret0 + m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2) } -// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. -func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// SetUDPPacketHook indicates an expected call of SetUDPPacketHook. +func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2) +} + +// SetTCPPacketHook mocks base method. +func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2) +} + +// SetTCPPacketHook indicates an expected call of SetTCPPacketHook. +func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2) } // FilterInbound mocks base method. @@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1) } - -// RemovePacketHook mocks base method. -func (m *MockPacketFilter) RemovePacketHook(arg0 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemovePacketHook", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemovePacketHook indicates an expected call of RemovePacketHook. -func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) -} diff --git a/client/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go deleted file mode 100644 index 291ab9ab5..000000000 --- a/client/iface/mocks/iface/mocks/filter.go +++ /dev/null @@ -1,87 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockPacketFilter is a mock of PacketFilter interface. -type MockPacketFilter struct { - ctrl *gomock.Controller - recorder *MockPacketFilterMockRecorder -} - -// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter. -type MockPacketFilterMockRecorder struct { - mock *MockPacketFilter -} - -// NewMockPacketFilter creates a new mock instance. -func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter { - mock := &MockPacketFilter{ctrl: ctrl} - mock.recorder = &MockPacketFilterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { - return m.recorder -} - -// AddUDPPacketHook mocks base method. -func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) -} - -// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. -func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) -} - -// FilterInbound mocks base method. -func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FilterInbound", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// FilterInbound indicates an expected call of FilterInbound. -func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0) -} - -// FilterOutbound mocks base method. -func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FilterOutbound", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// FilterOutbound indicates an expected call of FilterOutbound. -func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0) -} - -// SetNetwork mocks base method. -func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetNetwork", arg0) -} - -// SetNetwork indicates an expected call of SetNetwork. -func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0) -} diff --git a/client/iface/udpmux/universal.go b/client/iface/udpmux/universal.go index 43bfedaaa..89a7eefb9 100644 --- a/client/iface/udpmux/universal.go +++ b/client/iface/udpmux/universal.go @@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error { } if u.address.Network.Contains(a) { - log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) + log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) } @@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error { u.addrCache.Store(addr.String(), isRouted) if isRouted { // Extra log, as the error only shows up with ICE logging enabled - log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix) + log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix) return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) } } diff --git a/client/installer.nsis b/client/installer.nsis index 96d60a785..63bff1c5b 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -201,7 +201,18 @@ Pop $0 Function .onInit StrCpy $INSTDIR "${INSTALL_DIR}" +; Default autostart to enabled so silent installs (/S) match the interactive default +StrCpy $AutostartEnabled "1" + +; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live +; in the 32-bit view. Fall back to it so upgrades still find them. +SetRegView 64 ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" +${If} $R0 == "" + SetRegView 32 + ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" + SetRegView 64 +${EndIf} ${If} $R0 != "" # if silent install jump to uninstall step IfSilent uninstall @@ -214,6 +225,10 @@ ${If} $R0 != "" ${EndIf} FunctionEnd + +Function un.onInit +SetRegView 64 +FunctionEnd ###################################################################### Section -MainProgram ${INSTALL_TYPE} @@ -228,6 +243,7 @@ Section -MainProgram !else File /r "..\\dist\\netbird_windows_amd64\\" !endif + File "..\\client\\ui\\assets\\netbird.png" SectionEnd ###################################################################### @@ -247,9 +263,11 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" ; Create autostart registry entry based on checkbox DetailPrint "Autostart enabled: $AutostartEnabled" ${If} $AutostartEnabled == "1" - WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe" + WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"' DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe" ${Else} + DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" + ; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DetailPrint "Autostart not enabled by user" ${EndIf} @@ -283,6 +301,8 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ; Remove autostart registry entry DetailPrint "Removing autostart registry entry if exists..." +DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" +; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" ; Handle data deletion based on checkbox @@ -321,6 +341,7 @@ DetailPrint "Removing registry keys..." DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}" +DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}" DetailPrint "Removing application directory from PATH..." EnVar::SetHKLM diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index bd7adfaef..408ed992f 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -19,6 +19,9 @@ import ( var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() func TestDefaultManager(t *testing.T) { + t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") + networkMap := &mgmProto.NetworkMap{ FirewallRules: []*mgmProto.FirewallRule{ { @@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) { func TestDefaultManagerStateless(t *testing.T) { // stateless currently only in userspace, so we have to disable kernel t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") t.Setenv("NB_DISABLE_CONNTRACK", "true") networkMap := &mgmProto.NetworkMap{ @@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) { // This tests the full ACL manager -> uspfilter integration. func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) { t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") networkMap := &mgmProto.NetworkMap{ FirewallRules: []*mgmProto.FirewallRule{ @@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) { // up when they're removed from the network map in a subsequent update. func TestDenyRulesCleanedUpOnRemoval(t *testing.T) { t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) { // one added without leaking. func TestRuleUpdateChangingAction(t *testing.T) { t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go index 44e98bede..bdfd07430 100644 --- a/client/internal/auth/auth.go +++ b/client/internal/auth/auth.go @@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) { var needsLogin bool err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { - _, _, err := a.doMgmLogin(client, ctx, pubSSHKey) + err := a.doMgmLogin(client, ctx, pubSSHKey) if isLoginNeeded(err) { needsLogin = true return nil @@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err var isAuthError bool err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { - serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey) - if serverKey != nil && isRegistrationNeeded(err) { + err := a.doMgmLogin(client, ctx, pubSSHKey) + if isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey) if err != nil { @@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err // getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) { - serverKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err - } - - protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey) + protoFlow, err := client.GetPKCEAuthorizationFlow() if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { log.Warnf("server couldn't find pkce flow, contact admin: %v", err) @@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro config := &PKCEAuthProviderConfig{ Audience: protoConfig.GetAudience(), ClientID: protoConfig.GetClientID(), - ClientSecret: protoConfig.GetClientSecret(), + ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck TokenEndpoint: protoConfig.GetTokenEndpoint(), AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(), Scope: protoConfig.GetScope(), @@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro // getDeviceFlow retrieves device authorization flow configuration and creates a flow instance func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) { - serverKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err - } - - protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey) + protoFlow, err := client.GetDeviceAuthorizationFlow() if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { log.Warnf("server couldn't find device flow, contact admin: %v", err) @@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, config := &DeviceAuthProviderConfig{ Audience: protoConfig.GetAudience(), ClientID: protoConfig.GetClientID(), - ClientSecret: protoConfig.GetClientSecret(), + ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck Domain: protoConfig.Domain, TokenEndpoint: protoConfig.GetTokenEndpoint(), DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(), @@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, } // doMgmLogin performs the actual login operation with the management service -func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) { - serverKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, nil, err - } - +func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error { sysInfo := system.GetInfo(ctx) a.setSystemInfoFlags(sysInfo) - loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels) - return serverKey, loginResp, err + _, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels) + return err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // Otherwise tries to register with the provided setupKey via command line. func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) { - serverPublicKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err - } - validSetupKey, err := uuid.Parse(setupKey) if err != nil && jwtToken == "" { return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) @@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe log.Debugf("sending peer registration request to Management Service") info := system.GetInfo(ctx) a.setSystemInfoFlags(info) - loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels) + loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels) if err != nil { log.Errorf("failed registering peer %v", err) return nil, err diff --git a/client/internal/connect.go b/client/internal/connect.go index 242b25b44..72e096a80 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -44,6 +44,10 @@ import ( "github.com/netbirdio/netbird/version" ) +// androidRunOverride is set on Android to inject mobile dependencies +// when using embed.Client (which calls Run() with empty MobileDependency). +var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error + type ConnectClient struct { ctx context.Context config *profilemanager.Config @@ -76,6 +80,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) { // Run with main logic. func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error { + if androidRunOverride != nil { + return androidRunOverride(c, runningChan, logPath) + } return c.run(MobileDependency{}, runningChan, logPath) } @@ -87,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid( dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, stateFilePath string, + cacheDir string, ) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ @@ -96,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid( HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, StateFilePath: stateFilePath, + TempDir: cacheDir, } return c.run(mobileDependency, nil, "") } @@ -104,6 +113,7 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, + dnsAddresses []netip.AddrPort, stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. @@ -113,6 +123,7 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, + HostDNSAddresses: dnsAddresses, StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, "") @@ -322,6 +333,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.statusRecorder.MarkSignalConnected() relayURLs, token := parseRelayInfo(loginResp) + if override, ok := peer.OverrideRelayURLs(); ok { + log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override) + relayURLs = override + } peerConfig := loginResp.GetPeerConfig() engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath) @@ -329,6 +344,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan log.Error(err) return wrapErr(err) } + engineConfig.TempDir = mobileDependency.TempDir relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) c.statusRecorder.SetRelayMgr(relayManager) @@ -610,12 +626,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { - - serverPublicKey, err := client.GetServerPublicKey() - if err != nil { - return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err) - } - sysInfo := system.GetInfo(ctx) sysInfo.SetFlags( config.RosenpassEnabled, @@ -634,12 +644,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.EnableSSHRemotePortForwarding, config.DisableSSHAuth, ) - loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) - if err != nil { - return nil, err - } - - return loginResp, nil + return client.Login(sysInfo, pubSSHKey, config.DNSLabels) } func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier { diff --git a/client/internal/connect_android_default.go b/client/internal/connect_android_default.go new file mode 100644 index 000000000..190341c4a --- /dev/null +++ b/client/internal/connect_android_default.go @@ -0,0 +1,73 @@ +//go:build android + +package internal + +import ( + "net/netip" + + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android. +// It returns an empty interface list, which means ICE P2P candidates won't be +// discovered — connections will fall back to relay. Applications that need P2P +// should provide a real implementation via runOnAndroidEmbed that uses +// Android's ConnectivityManager to enumerate network interfaces. +type noopIFaceDiscover struct{} + +func (noopIFaceDiscover) IFaces() (string, error) { + // Return empty JSON array — no local interfaces advertised for ICE. + // This is intentional: without Android's ConnectivityManager, we cannot + // reliably enumerate interfaces (netlink is restricted on Android 11+). + // Relay connections still work; only P2P hole-punching is disabled. + return "[]", nil +} + +// noopNetworkChangeListener is a stub for embed.Client on Android. +// Network change events are ignored since the embed client manages its own +// reconnection logic via the engine's built-in retry mechanism. +type noopNetworkChangeListener struct{} + +func (noopNetworkChangeListener) OnNetworkChanged(string) { + // No-op: embed.Client relies on the engine's internal reconnection + // logic rather than OS-level network change notifications. +} + +func (noopNetworkChangeListener) SetInterfaceIP(string) { + // No-op: in netstack mode, the overlay IP is managed by the userspace + // network stack, not by OS-level interface configuration. +} + +// noopDnsReadyListener is a stub for embed.Client on Android. +// DNS readiness notifications are not needed in netstack/embed mode +// since system DNS is disabled and DNS resolution happens externally. +type noopDnsReadyListener struct{} + +func (noopDnsReadyListener) OnReady() { + // No-op: embed.Client does not need DNS readiness notifications. + // System DNS is disabled in netstack mode. +} + +var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{} +var _ listener.NetworkChangeListener = noopNetworkChangeListener{} +var _ dns.ReadyListener = noopDnsReadyListener{} + +func init() { + // Wire up the default override so embed.Client.Start() works on Android + // with netstack mode. Provides complete no-op stubs for all mobile + // dependencies so the engine's existing Android code paths work unchanged. + // Applications that need P2P ICE or real DNS should replace this by + // setting androidRunOverride before calling Start(). + androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error { + return c.runOnAndroidEmbed( + noopIFaceDiscover{}, + noopNetworkChangeListener{}, + []netip.AddrPort{}, + noopDnsReadyListener{}, + runningChan, + logPath, + ) + } +} diff --git a/client/internal/connect_android_embed.go b/client/internal/connect_android_embed.go new file mode 100644 index 000000000..18f72e841 --- /dev/null +++ b/client/internal/connect_android_embed.go @@ -0,0 +1,32 @@ +//go:build android + +package internal + +import ( + "net/netip" + + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan +// so embed.Client.Start() can detect when the engine is ready. +// It provides complete MobileDependency so the engine's existing +// Android code paths work unchanged. +func (c *ConnectClient) runOnAndroidEmbed( + iFaceDiscover stdnet.ExternalIFaceDiscover, + networkChangeListener listener.NetworkChangeListener, + dnsAddresses []netip.AddrPort, + dnsReadyListener dns.ReadyListener, + runningChan chan struct{}, + logPath string, +) error { + mobileDependency := MobileDependency{ + IFaceDiscover: iFaceDiscover, + NetworkChangeListener: networkChangeListener, + HostDNSAddresses: dnsAddresses, + DnsReadyListener: dnsReadyListener, + } + return c.run(mobileDependency, runningChan, logPath) +} diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index c9ebf25e5..90560d028 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -16,7 +16,6 @@ import ( "path/filepath" "runtime" "runtime/pprof" - "slices" "sort" "strings" "time" @@ -25,12 +24,12 @@ import ( "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/updater/installer" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" - "github.com/netbirdio/netbird/util" ) const readmeContent = `Netbird debug bundle @@ -52,6 +51,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states for the active profile. +service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists. metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. @@ -61,6 +61,7 @@ allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. cpu.prof: CPU profiling information. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. +capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data. Anonymization Process @@ -232,7 +233,9 @@ type BundleGenerator struct { statusRecorder *peer.Status syncResponse *mgmProto.SyncResponse logPath string + tempDir string cpuProfile []byte + capturePath string refreshStatus func() // Optional callback to refresh status before bundle generation clientMetrics MetricsExporter @@ -254,8 +257,10 @@ type GeneratorDependencies struct { StatusRecorder *peer.Status SyncResponse *mgmProto.SyncResponse LogPath string + TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used. CPUProfile []byte - RefreshStatus func() // Optional callback to refresh status before bundle generation + CapturePath string + RefreshStatus func() ClientMetrics MetricsExporter } @@ -273,7 +278,9 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen statusRecorder: deps.StatusRecorder, syncResponse: deps.SyncResponse, logPath: deps.LogPath, + tempDir: deps.TempDir, cpuProfile: deps.CPUProfile, + capturePath: deps.CapturePath, refreshStatus: deps.RefreshStatus, clientMetrics: deps.ClientMetrics, @@ -285,7 +292,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen // Generate creates a debug bundle and returns the location. func (g *BundleGenerator) Generate() (resp string, err error) { - bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") + bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip") if err != nil { return "", fmt.Errorf("create zip file: %w", err) } @@ -343,6 +350,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add CPU profile to debug bundle: %v", err) } + if err := g.addCaptureFile(); err != nil { + log.Errorf("failed to add capture file to debug bundle: %v", err) + } + if err := g.addStackTrace(); err != nil { log.Errorf("failed to add stack trace to debug bundle: %v", err) } @@ -359,6 +370,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add corrupted state files to debug bundle: %v", err) } + if err := g.addServiceParams(); err != nil { + log.Errorf("failed to add service params to debug bundle: %v", err) + } + if err := g.addMetrics(); err != nil { log.Errorf("failed to add metrics to debug bundle: %v", err) } @@ -367,15 +382,8 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add wg show output: %v", err) } - if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { - if err := g.addLogfile(); err != nil { - log.Errorf("failed to add log file to debug bundle: %v", err) - if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("failed to add systemd logs as fallback: %v", err) - } - } - } else if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("failed to add systemd logs: %v", err) + if err := g.addPlatformLog(); err != nil { + log.Errorf("failed to add logs to debug bundle: %v", err) } if err := g.addUpdateLogs(); err != nil { @@ -488,6 +496,90 @@ func (g *BundleGenerator) addConfig() error { return nil } +const ( + serviceParamsFile = "service.json" + serviceParamsBundle = "service_params.json" + maskedValue = "***" + envVarPrefix = "NB_" + jsonKeyManagementURL = "management_url" + jsonKeyServiceEnv = "service_env_vars" +) + +var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"} + +// addServiceParams reads the service.json file and adds a sanitized version to the bundle. +// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized. +func (g *BundleGenerator) addServiceParams() error { + path := filepath.Join(configs.StateDir, serviceParamsFile) + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read service params: %w", err) + } + + var params map[string]any + if err := json.Unmarshal(data, ¶ms); err != nil { + return fmt.Errorf("parse service params: %w", err) + } + + if g.anonymize { + if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" { + params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL) + } + } + + g.sanitizeServiceEnvVars(params) + + sanitizedData, err := json.MarshalIndent(params, "", " ") + if err != nil { + return fmt.Errorf("marshal sanitized service params: %w", err) + } + + if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil { + return fmt.Errorf("add service params to zip: %w", err) + } + + return nil +} + +// sanitizeServiceEnvVars masks or anonymizes env var values in service params. +// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked. +// Other NB_ var values are passed through the anonymizer when anonymization is enabled. +func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) { + envVars, ok := params[jsonKeyServiceEnv].(map[string]any) + if !ok { + return + } + + sanitized := make(map[string]any, len(envVars)) + for k, v := range envVars { + val, _ := v.(string) + switch { + case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k): + sanitized[k] = maskedValue + case g.anonymize: + sanitized[k] = g.anonymizer.AnonymizeString(val) + default: + sanitized[k] = val + } + } + params[jsonKeyServiceEnv] = sanitized +} + +// isSensitiveEnvVar returns true for env var names that may contain secrets. +func isSensitiveEnvVar(key string) bool { + lower := strings.ToLower(key) + for _, s := range sensitiveEnvSubstrings { + if strings.Contains(lower, s) { + return true + } + } + return false +} + func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { configContent.WriteString("NetBird Client Configuration:\n\n") @@ -585,6 +677,29 @@ func (g *BundleGenerator) addCPUProfile() error { return nil } +func (g *BundleGenerator) addCaptureFile() error { + if g.capturePath == "" { + return nil + } + + if g.anonymize { + log.Info("skipping capture file in anonymized bundle (contains raw packet data)") + return nil + } + + f, err := os.Open(g.capturePath) + if err != nil { + return fmt.Errorf("open capture file: %w", err) + } + defer f.Close() + + if err := g.addFileToZip(f, "capture.pcap"); err != nil { + return fmt.Errorf("add capture file to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addStackTrace() error { buf := make([]byte, 5242880) // 5 MB buffer n := runtime.Stack(buf, true) diff --git a/client/internal/debug/debug_android.go b/client/internal/debug/debug_android.go new file mode 100644 index 000000000..a4e2b3e98 --- /dev/null +++ b/client/internal/debug/debug_android.go @@ -0,0 +1,41 @@ +//go:build android + +package debug + +import ( + "fmt" + "io" + "os/exec" + + log "github.com/sirupsen/logrus" +) + +func (g *BundleGenerator) addPlatformLog() error { + cmd := exec.Command("/system/bin/logcat", "-d") + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("logcat stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("start logcat: %w", err) + } + + var logReader io.Reader = stdout + if g.anonymize { + var pw *io.PipeWriter + logReader, pw = io.Pipe() + go anonymizeLog(stdout, pw, g.anonymizer) + } + + if err := g.addFileToZip(logReader, "logcat.txt"); err != nil { + return fmt.Errorf("add logcat to zip: %w", err) + } + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("wait logcat: %w", err) + } + + log.Debug("added logcat output to debug bundle") + return nil +} diff --git a/client/internal/debug/debug_nonandroid.go b/client/internal/debug/debug_nonandroid.go new file mode 100644 index 000000000..117238dec --- /dev/null +++ b/client/internal/debug/debug_nonandroid.go @@ -0,0 +1,25 @@ +//go:build !android + +package debug + +import ( + "slices" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +func (g *BundleGenerator) addPlatformLog() error { + if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { + if err := g.addLogfile(); err != nil { + log.Errorf("failed to add log file to debug bundle: %v", err) + if err := g.trySystemdLogFallback(); err != nil { + return err + } + } + } else if err := g.trySystemdLogFallback(); err != nil { + return err + } + return nil +} diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 59837c328..6b5bb911c 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -1,8 +1,12 @@ package debug import ( + "archive/zip" + "bytes" "encoding/json" "net" + "os" + "path/filepath" "strings" "testing" @@ -10,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/configs" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) { } } +func TestIsSensitiveEnvVar(t *testing.T) { + tests := []struct { + key string + sensitive bool + }{ + {"NB_SETUP_KEY", true}, + {"NB_API_TOKEN", true}, + {"NB_CLIENT_SECRET", true}, + {"NB_PASSWORD", true}, + {"NB_CREDENTIAL", true}, + {"NB_LOG_LEVEL", false}, + {"NB_MANAGEMENT_URL", false}, + {"NB_HOSTNAME", false}, + {"HOME", false}, + {"PATH", false}, + } + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key)) + }) + } +} + +func TestSanitizeServiceEnvVars(t *testing.T) { + tests := []struct { + name string + anonymize bool + input map[string]any + check func(t *testing.T, params map[string]any) + }{ + { + name: "no env vars key", + anonymize: false, + input: map[string]any{"management_url": "https://mgmt.example.com"}, + check: func(t *testing.T, params map[string]any) { + t.Helper() + assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched") + _, ok := params[jsonKeyServiceEnv] + assert.False(t, ok, "service_env_vars should not be added") + }, + }, + { + name: "non-NB vars are masked", + anonymize: false, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "HOME": "/root", + "PATH": "/usr/bin", + "NB_LOG_LEVEL": "debug", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked") + assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked") + assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through") + }, + }, + { + name: "sensitive NB vars are masked", + anonymize: false, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "NB_SETUP_KEY": "abc123", + "NB_API_TOKEN": "tok_xyz", + "NB_LOG_LEVEL": "info", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked") + assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked") + assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through") + }, + }, + { + name: "safe NB vars anonymized when anonymize is true", + anonymize: true, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "NB_MANAGEMENT_URL": "https://mgmt.example.com:443", + "NB_LOG_LEVEL": "debug", + "NB_SETUP_KEY": "secret", + "SOME_OTHER": "val", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + // Safe NB_ values should be anonymized (not the original, not masked) + mgmtVal := env["NB_MANAGEMENT_URL"].(string) + assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized") + assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked") + + logVal := env["NB_LOG_LEVEL"].(string) + assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked") + + // Sensitive and non-NB_ still masked + assert.Equal(t, maskedValue, env["NB_SETUP_KEY"]) + assert.Equal(t, maskedValue, env["SOME_OTHER"]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + g := &BundleGenerator{ + anonymize: tt.anonymize, + anonymizer: anonymizer, + } + g.sanitizeServiceEnvVars(tt.input) + tt.check(t, tt.input) + }) + } +} + +func TestAddServiceParams(t *testing.T) { + t.Run("missing service.json returns nil", func(t *testing.T) { + g := &BundleGenerator{ + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + } + + origStateDir := configs.StateDir + configs.StateDir = t.TempDir() + t.Cleanup(func() { configs.StateDir = origStateDir }) + + err := g.addServiceParams() + assert.NoError(t, err) + }) + + t.Run("management_url anonymized when anonymize is true", func(t *testing.T) { + dir := t.TempDir() + origStateDir := configs.StateDir + configs.StateDir = dir + t.Cleanup(func() { configs.StateDir = origStateDir }) + + input := map[string]any{ + jsonKeyManagementURL: "https://api.example.com:443", + jsonKeyServiceEnv: map[string]any{ + "NB_LOG_LEVEL": "trace", + }, + } + data, err := json.Marshal(input) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600)) + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + g := &BundleGenerator{ + anonymize: true, + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + archive: zw, + } + + require.NoError(t, g.addServiceParams()) + require.NoError(t, zw.Close()) + + zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) + require.NoError(t, err) + require.Len(t, zr.File, 1) + assert.Equal(t, serviceParamsBundle, zr.File[0].Name) + + rc, err := zr.File[0].Open() + require.NoError(t, err) + defer rc.Close() + + var result map[string]any + require.NoError(t, json.NewDecoder(rc).Decode(&result)) + + mgmt := result[jsonKeyManagementURL].(string) + assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized") + assert.NotEmpty(t, mgmt) + + env := result[jsonKeyServiceEnv].(map[string]any) + assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked") + }) + + t.Run("management_url preserved when anonymize is false", func(t *testing.T) { + dir := t.TempDir() + origStateDir := configs.StateDir + configs.StateDir = dir + t.Cleanup(func() { configs.StateDir = origStateDir }) + + input := map[string]any{ + jsonKeyManagementURL: "https://api.example.com:443", + } + data, err := json.Marshal(input) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600)) + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + g := &BundleGenerator{ + anonymize: false, + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + archive: zw, + } + + require.NoError(t, g.addServiceParams()) + require.NoError(t, zw.Close()) + + zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) + require.NoError(t, err) + + rc, err := zr.File[0].Open() + require.NoError(t, err) + defer rc.Close() + + var result map[string]any + require.NoError(t, json.NewDecoder(rc).Decode(&result)) + + assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved") + }) +} + // Helper function to check if IP is in CGNAT range func isInCGNATRange(ip net.IP) bool { cgnat := net.IPNet{ diff --git a/client/internal/debug/upload_test.go b/client/internal/debug/upload_test.go index e833c196d..f224b8d3f 100644 --- a/client/internal/debug/upload_test.go +++ b/client/internal/debug/upload_test.go @@ -3,10 +3,12 @@ package debug import ( "context" "errors" + "net" "net/http" "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/require" @@ -19,8 +21,10 @@ func TestUpload(t *testing.T) { t.Skip("Skipping upload test on docker ci") } testDir := t.TempDir() - testURL := "http://localhost:8080" + addr := reserveLoopbackPort(t) + testURL := "http://" + addr t.Setenv("SERVER_URL", testURL) + t.Setenv("SERVER_ADDRESS", addr) t.Setenv("STORE_DIR", testDir) srv := server.NewServer() go func() { @@ -33,6 +37,7 @@ func TestUpload(t *testing.T) { t.Errorf("Failed to stop server: %v", err) } }) + waitForServer(t, addr) file := filepath.Join(t.TempDir(), "tmpfile") fileContent := []byte("test file content") @@ -47,3 +52,30 @@ func TestUpload(t *testing.T) { require.NoError(t, err) require.Equal(t, fileContent, createdFileContent) } + +// reserveLoopbackPort binds an ephemeral port on loopback to learn a free +// address, then releases it so the server under test can rebind. The close/ +// rebind window is racy in theory; on loopback with a kernel-assigned port +// it's essentially never contended in practice. +func reserveLoopbackPort(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + return addr +} + +func waitForServer(t *testing.T, addr string) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err == nil { + _ = c.Close() + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("server did not start listening on %s in time", addr) +} diff --git a/client/internal/dns/file_parser_unix.go b/client/internal/dns/file_parser_unix.go index 8dacb4e51..50ba74c0c 100644 --- a/client/internal/dns/file_parser_unix.go +++ b/client/internal/dns/file_parser_unix.go @@ -13,6 +13,7 @@ import ( const ( defaultResolvConfPath = "/etc/resolv.conf" + nsswitchConfPath = "/etc/nsswitch.conf" ) type resolvConf struct { diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 06a2056b1..57e7722d4 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,7 +1,10 @@ package dns import ( + "context" "fmt" + "math" + "net" "slices" "strconv" "strings" @@ -73,6 +76,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { return nil } w.response = m + if m.MsgHdr.Truncated { + w.SetMeta("truncated", "true") + } return w.ResponseWriter.WriteMsg(m) } @@ -189,16 +195,26 @@ func (c *HandlerChain) logHandlers() { } func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + c.dispatch(w, r, math.MaxInt) +} + +// dispatch routes a DNS request through the chain, skipping handlers with +// priority > maxPriority. Shared by ServeDNS and ResolveInternal. +func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) { if len(r.Question) == 0 { return } startTime := time.Now() requestID := resutil.GenerateRequestID() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": requestID, "dns_id": fmt.Sprintf("%04x", r.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) question := r.Question[0] qname := strings.ToLower(question.Name) @@ -209,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // Try handlers in priority order for _, entry := range handlers { + if entry.Priority > maxPriority { + continue + } if !c.isHandlerMatch(qname, entry) { continue } @@ -261,9 +280,58 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q meta += " " + k + "=" + v } - logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s", + logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s", qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer), - meta, time.Since(startTime)) + cw.response.Len(), meta, time.Since(startTime)) +} + +// ResolveInternal runs an in-process DNS query against the chain, skipping any +// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt +// cache refresher) that must bypass themselves to avoid loops. Honors ctx +// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own +// (bounded by the invoked handler's internal timeout). +func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) { + if len(r.Question) == 0 { + return nil, fmt.Errorf("empty question") + } + + base := &internalResponseWriter{} + done := make(chan struct{}) + go func() { + c.dispatch(base, r, maxPriority) + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + // Prefer a completed response if dispatch finished concurrently with cancellation. + select { + case <-done: + default: + return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) + } + } + + if base.response == nil || base.response.Rcode == dns.RcodeRefused { + return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d", + strings.ToLower(r.Question[0].Name), maxPriority) + } + return base.response, nil +} + +// HasRootHandlerAtOrBelow reports whether any "." handler is registered at +// priority ≤ maxPriority. +func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, h := range c.handlers { + if h.Pattern == "." && h.Priority <= maxPriority { + return true + } + } + return false } func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { @@ -284,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { } } } + +// internalResponseWriter captures a dns.Msg for in-process chain queries. +type internalResponseWriter struct { + response *dns.Msg +} + +func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil } +func (w *internalResponseWriter) LocalAddr() net.Addr { return nil } +func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil } + +// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg +// still surface their answer to ResolveInternal. +func (w *internalResponseWriter) Write(p []byte) (int, error) { + msg := new(dns.Msg) + if err := msg.Unpack(p); err != nil { + return 0, err + } + w.response = msg + return len(p), nil +} + +func (w *internalResponseWriter) Close() error { return nil } +func (w *internalResponseWriter) TsigStatus() error { return nil } + +// TsigTimersOnly is part of dns.ResponseWriter. +func (w *internalResponseWriter) TsigTimersOnly(bool) { + // no-op: in-process queries carry no TSIG state. +} + +// Hijack is part of dns.ResponseWriter. +func (w *internalResponseWriter) Hijack() { + // no-op: in-process queries have no underlying connection to hand off. +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index fa9525069..034a760dc 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,11 +1,15 @@ package dns_test import ( + "context" + "net" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/test" @@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { }) } } + +// answeringHandler writes a fixed A record to ack the query. Used to verify +// which handler ResolveInternal dispatches to. +type answeringHandler struct { + name string + ip string +} + +func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + _ = w.WriteMsg(resp) +} + +func (h *answeringHandler) String() string { return h.name } + +func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) { + chain := nbdns.NewHandlerChain() + + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + low := &answeringHandler{name: "low", ip: "10.0.0.2"} + + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + chain.AddHandler("example.com.", low, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 1, len(resp.Answer)) + a, ok := resp.Answer[0].(*dns.A) + assert.True(t, ok) + assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream") +} + +func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) { + chain := nbdns.NewHandlerChain() + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + _, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.Error(t, err, "no handler at or below maxPriority should error") +} + +// rawWriteHandler packs a response and calls ResponseWriter.Write directly +// (instead of WriteMsg), exercising the internalResponseWriter.Write path. +type rawWriteHandler struct { + ip string +} + +func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + packed, err := resp.Pack() + if err != nil { + return + } + _, _ = w.Write(packed) +} + +func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) { + chain := nbdns.NewHandlerChain() + chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Answer, 1) + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer") +} + +func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) { + chain := nbdns.NewHandlerChain() + _, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream) + assert.Error(t, err) +} + +// hangingHandler blocks indefinitely until closed, simulating a wedged upstream. +type hangingHandler struct { + block chan struct{} +} + +func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + <-h.block + resp := &dns.Msg{} + resp.SetReply(r) + _ = w.WriteMsg(resp) +} + +func (h *hangingHandler) String() string { return "hangingHandler" } + +func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &hangingHandler{block: make(chan struct{})} + defer close(h.block) + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline") +} + +func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &answeringHandler{name: "h", ip: "10.0.0.1"} + + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain") + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count") + + chain.AddHandler(".", h, nbdns.PriorityMgmtCache) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded") + + chain.AddHandler(".", h, nbdns.PriorityDefault) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included") + + chain.RemoveHandler(".", nbdns.PriorityDefault) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) + + // Primary nsgroup case: root handler lands at PriorityUpstream. + chain.AddHandler(".", h, nbdns.PriorityUpstream) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included") + chain.RemoveHandler(".", nbdns.PriorityUpstream) + + // Fallback case: original /etc/resolv.conf entries land at PriorityFallback. + chain.AddHandler(".", h, nbdns.PriorityFallback) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included") + chain.RemoveHandler(".", nbdns.PriorityFallback) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) +} diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 422fed4e5..d7301d725 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -46,12 +46,12 @@ type restoreHostManager interface { } func newHostManager(wgInterface string) (hostManager, error) { - osManager, err := getOSDNSManagerType() + osManager, reason, err := getOSDNSManagerType() if err != nil { return nil, fmt.Errorf("get os dns manager type: %w", err) } - log.Infof("System DNS manager discovered: %s", osManager) + log.Infof("System DNS manager discovered: %s (%s)", osManager, reason) mgr, err := newHostManagerFromType(wgInterface, osManager) // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value if err != nil { @@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor } } -func getOSDNSManagerType() (osManagerType, error) { +func getOSDNSManagerType() (osManagerType, string, error) { + resolved := isSystemdResolvedRunning() + nss := isLibnssResolveUsed() + stub := checkStub() + + // Prefer systemd-resolved whenever it owns libc resolution, regardless of + // who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups + // that go through nss-resolve, and in foreign mode they can loop back + // through resolved as an upstream. + if resolved && (nss || stub) { + return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil + } + + mgr, reason, rejected, err := scanResolvConfHeader() + if err != nil { + return 0, "", err + } + if reason != "" { + return mgr, reason, nil + } + + fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub) + if len(rejected) > 0 { + fallback += "; rejected: " + strings.Join(rejected, ", ") + } + return fileManager, fallback, nil +} + +// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the +// matching manager. If reason is empty the caller should pick file mode and +// use rejected for diagnostics. +func scanResolvConfHeader() (osManagerType, string, []string, error) { file, err := os.Open(defaultResolvConfPath) if err != nil { - return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) + return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) } defer func() { - if err := file.Close(); err != nil { - log.Errorf("close file %s: %s", defaultResolvConfPath, err) + if cerr := file.Close(); cerr != nil { + log.Errorf("close file %s: %s", defaultResolvConfPath, cerr) } }() + var rejected []string scanner := bufio.NewScanner(file) for scanner.Scan() { text := scanner.Text() @@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) { continue } if text[0] != '#' { - return fileManager, nil + break } - if strings.Contains(text, fileGeneratedResolvConfContentHeader) { - return netbirdManager, nil - } - if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { - return networkManager, nil - } - if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() { - if checkStub() { - return systemdManager, nil - } else { - return fileManager, nil - } - } - if strings.Contains(text, "resolvconf") { - if isSystemdResolveConfMode() { - return systemdManager, nil - } - - return resolvConfManager, nil + if mgr, reason, rej := matchResolvConfHeader(text); reason != "" { + return mgr, reason, nil, nil + } else if rej != "" { + rejected = append(rejected, rej) } } if err := scanner.Err(); err != nil && err != io.EOF { - return 0, fmt.Errorf("scan: %w", err) + return 0, "", nil, fmt.Errorf("scan: %w", err) } - - return fileManager, nil + return 0, "", rejected, nil } -// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager. +// matchResolvConfHeader inspects a single comment line. Returns either a +// definitive (manager, reason) or a non-empty rejected diagnostic. +func matchResolvConfHeader(text string) (osManagerType, string, string) { + if strings.Contains(text, fileGeneratedResolvConfContentHeader) { + return netbirdManager, "netbird-managed resolv.conf header detected", "" + } + if strings.Contains(text, "NetworkManager") { + if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { + return networkManager, "NetworkManager header + supported version on dbus", "" + } + return 0, "", "NetworkManager header (no dbus or unsupported version)" + } + if strings.Contains(text, "resolvconf") { + if isSystemdResolveConfMode() { + return systemdManager, "resolvconf header in systemd-resolved compatibility mode", "" + } + return resolvConfManager, "resolvconf header detected", "" + } + return 0, "", "" +} + +// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed +// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping +// into file mode while resolved is active. func checkStub() bool { rConf, err := parseDefaultResolvConf() if err != nil { - log.Warnf("failed to parse resolv conf: %s", err) + log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err) return true } @@ -139,3 +178,36 @@ func checkStub() bool { return false } + +// isLibnssResolveUsed reports whether nss-resolve is listed before dns on +// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are +// delegated to systemd-resolved regardless of /etc/resolv.conf. +func isLibnssResolveUsed() bool { + bs, err := os.ReadFile(nsswitchConfPath) + if err != nil { + log.Debugf("read %s: %v", nsswitchConfPath, err) + return false + } + return parseNsswitchResolveAhead(bs) +} + +func parseNsswitchResolveAhead(data []byte) bool { + for _, line := range strings.Split(string(data), "\n") { + if i := strings.IndexByte(line, '#'); i >= 0 { + line = line[:i] + } + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "hosts:" { + continue + } + for _, module := range fields[1:] { + switch module { + case "dns": + return false + case "resolve": + return true + } + } + } + return false +} diff --git a/client/internal/dns/host_unix_test.go b/client/internal/dns/host_unix_test.go new file mode 100644 index 000000000..e936281d3 --- /dev/null +++ b/client/internal/dns/host_unix_test.go @@ -0,0 +1,76 @@ +//go:build (linux && !android) || freebsd + +package dns + +import "testing" + +func TestParseNsswitchResolveAhead(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + { + name: "resolve before dns with action token", + in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n", + want: true, + }, + { + name: "dns before resolve", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n", + want: false, + }, + { + name: "debian default with only dns", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n", + want: false, + }, + { + name: "neither resolve nor dns", + in: "hosts: files myhostname\n", + want: false, + }, + { + name: "no hosts line", + in: "passwd: files systemd\ngroup: files systemd\n", + want: false, + }, + { + name: "empty", + in: "", + want: false, + }, + { + name: "comments and blank lines ignored", + in: "# comment\n\n# another\nhosts: resolve dns\n", + want: true, + }, + { + name: "trailing inline comment", + in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n", + want: true, + }, + { + name: "hosts token must be the first field", + in: " hosts: resolve dns\n", + want: true, + }, + { + name: "other db line mentioning resolve is ignored", + in: "networks: resolve\nhosts: dns\n", + want: false, + }, + { + name: "only resolve, no dns", + in: "hosts: files resolve\n", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want { + t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 73f70035f..2c6b7dbc3 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) { }) } -// TestLocalResolver_Stop tests cleanup on Stop +// TestLocalResolver_Stop tests cleanup on GracefullyStop func TestLocalResolver_Stop(t *testing.T) { - t.Run("Stop clears all state", func(t *testing.T) { + t.Run("GracefullyStop clears all state", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) { assert.False(t, resolver.isInManagedZone("host.example.com.")) }) - t.Run("Stop is safe to call multiple times", func(t *testing.T) { + t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) { resolver.Stop() }) - t.Run("Stop cancels in-flight external resolution", func(t *testing.T) { + t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) { resolver := NewResolver() lookupStarted := make(chan struct{}) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 314af51d9..988e427fb 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -2,40 +2,83 @@ package mgmt import ( "context" + "errors" "fmt" "net" - "net/netip" "net/url" + "os" + "slices" "strings" "sync" + "sync/atomic" "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/shared/management/domain" ) -const dnsTimeout = 5 * time.Second +const ( + dnsTimeout = 5 * time.Second + defaultTTL = 300 * time.Second + refreshBackoff = 30 * time.Second -// Resolver caches critical NetBird infrastructure domains + // envMgmtCacheTTL overrides defaultTTL for integration/dev testing. + envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" +) + +// ChainResolver lets the cache refresh stale entries through the DNS handler +// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the +// system resolver. +type ChainResolver interface { + ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) + HasRootHandlerAtOrBelow(maxPriority int) bool +} + +// cachedRecord holds DNS records plus timestamps used for TTL refresh. +// records and cachedAt are set at construction and treated as immutable; +// lastFailedRefresh and consecFailures are mutable and must be accessed under +// Resolver.mutex. +type cachedRecord struct { + records []dns.RR + cachedAt time.Time + lastFailedRefresh time.Time + consecFailures int +} + +// Resolver caches critical NetBird infrastructure domains. +// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex. type Resolver struct { - records map[dns.Question][]dns.RR + records map[dns.Question]*cachedRecord mgmtDomain *domain.Domain serverDomains *dnsconfig.ServerDomains mutex sync.RWMutex -} -type ipsResponse struct { - ips []netip.Addr - err error + chain ChainResolver + chainMaxPriority int + refreshGroup singleflight.Group + + // refreshing tracks questions whose refresh is running via the OS + // fallback path. A ServeDNS hit for a question in this map indicates + // the OS resolver routed the recursive query back to us (loop). Only + // the OS path arms this so chain-path refreshes don't produce false + // positives. The atomic bool is CAS-flipped once per refresh to + // throttle the warning log. + refreshing map[dns.Question]*atomic.Bool + + cacheTTL time.Duration } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - records: make(map[dns.Question][]dns.RR), + records: make(map[dns.Question]*cachedRecord), + refreshing: make(map[dns.Question]*atomic.Bool), + cacheTTL: resolveCacheTTL(), } } @@ -44,7 +87,19 @@ func (m *Resolver) String() string { return "MgmtCacheResolver" } -// ServeDNS implements dns.Handler interface. +// SetChainResolver wires the handler chain used to refresh stale cache entries. +// maxPriority caps which handlers may answer refresh queries (typically +// PriorityUpstream, so upstream/default/fallback handlers are consulted and +// mgmt/route/local handlers are skipped). +func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { + m.mutex.Lock() + m.chain = chain + m.chainMaxPriority = maxPriority + m.mutex.Unlock() +} + +// ServeDNS serves cached A/AAAA records. Stale entries are returned +// immediately and refreshed asynchronously (stale-while-revalidate). func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - records, found := m.records[question] + cached, found := m.records[question] + inflight := m.refreshing[question] + var shouldRefresh bool + if found { + stale := time.Since(cached.cachedAt) > m.cacheTTL + inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff + shouldRefresh = stale && !inBackoff + } m.mutex.RUnlock() if !found { @@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + if inflight != nil && inflight.CompareAndSwap(false, true) { + log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)", + question.Name) + } + + // Skip scheduling a refresh goroutine if one is already inflight for + // this question; singleflight would dedup anyway but skipping avoids + // a parked goroutine per stale hit under bursty load. + if shouldRefresh && inflight == nil { + m.scheduleRefresh(question, cached) + } + resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - - resp.Answer = append(resp.Answer, records...) + resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt)) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { } } -// AddDomain manually adds a domain to cache by resolving it. +// AddDomain resolves a domain and stores its A/AAAA records in the cache. +// A family that resolves NODATA (nil err, zero records) evicts any stale +// entry for that qtype. func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := lookupIPWithExtraTimeout(ctx, d) - if err != nil { - return err + aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) + + if errA != nil && errAAAA != nil { + return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) } - var aRecords, aaaaRecords []dns.RR - for _, ip := range ips { - if ip.Is4() { - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: ip.AsSlice(), - } - aRecords = append(aRecords, rr) - } else if ip.Is6() { - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: ip.AsSlice(), - } - aaaaRecords = append(aaaaRecords, rr) + if len(aRecords) == 0 && len(aaaaRecords) == 0 { + if err := errors.Join(errA, errAAAA); err != nil { + return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) } + return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } + now := time.Now() m.mutex.Lock() + defer m.mutex.Unlock() - if len(aRecords) > 0 { - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - m.records[aQuestion] = aRecords - } + m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) + m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) - if len(aaaaRecords) > 0 { - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - m.records[aaaaQuestion] = aaaaRecords - } - - m.mutex.Unlock() - - log.Debugf("added domain=%s with %d A records and %d AAAA records", + log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", d.SafeString(), len(aRecords), len(aaaaRecords)) return nil } -func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { - log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) - defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) - resultChan := make(chan *ipsResponse, 1) +// applyFamilyRecords writes records, evicts on NODATA, leaves the cache +// untouched on error. Caller holds m.mutex. +func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { + q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} + switch { + case len(records) > 0: + m.records[q] = &cachedRecord{records: records, cachedAt: now} + case err == nil: + delete(m.records, q) + } +} - go func() { - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) - resultChan <- &ipsResponse{ - err: err, - ips: ips, +// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per +// unique in-flight key; bursty stale hits share its channel. expected is the +// cachedRecord pointer observed by the caller; the refresh only mutates the +// cache if that pointer is still the one stored, so a stale in-flight refresh +// can't clobber a newer entry written by AddDomain or a competing refresh. +func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) { + key := question.Name + "|" + dns.TypeToString[question.Qtype] + _ = m.refreshGroup.DoChan(key, func() (any, error) { + return nil, m.refreshQuestion(question, expected) + }) +} + +// refreshQuestion replaces the cached records on success, or marks the entry +// failed (arming the backoff) on failure. While this runs, ServeDNS can detect +// a resolver loop by spotting a query for this same question arriving on us. +// expected pins the cache entry observed at schedule time; mutations only apply +// if m.records[question] still points at it. +func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error { + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) + if err != nil { + m.markRefreshFailed(question, expected) + return fmt.Errorf("parse domain: %w", err) + } + + records, err := m.lookupRecords(ctx, d, question) + if err != nil { + fails := m.markRefreshFailed(question, expected) + logf := log.Warnf + if fails == 0 || fails > 1 { + logf = log.Debugf } - }() - - var resp *ipsResponse - - select { - case <-time.After(dnsTimeout + time.Millisecond*500): - log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) - return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) - case <-ctx.Done(): - return nil, ctx.Err() - case resp = <-resultChan: + logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", + d.SafeString(), dns.TypeToString[question.Qtype], err, fails) + return err } - if resp.err != nil { - return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. + if len(records) == 0 { + m.mutex.Lock() + if m.records[question] == expected { + delete(m.records, question) + m.mutex.Unlock() + log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.mutex.Unlock() + log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil } - return resp.ips, nil + + now := time.Now() + m.mutex.Lock() + if m.records[question] != expected { + m.mutex.Unlock() + log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.records[question] = &cachedRecord{records: records, cachedAt: now} + m.mutex.Unlock() + + log.Infof("refreshed mgmt cache domain=%s type=%s", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil +} + +func (m *Resolver) markRefreshing(question dns.Question) { + m.mutex.Lock() + m.refreshing[question] = &atomic.Bool{} + m.mutex.Unlock() +} + +func (m *Resolver) clearRefreshing(question dns.Question) { + m.mutex.Lock() + delete(m.refreshing, question) + m.mutex.Unlock() +} + +// markRefreshFailed arms the backoff and returns the new consecutive-failure +// count so callers can downgrade subsequent failure logs to debug. +func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int { + m.mutex.Lock() + defer m.mutex.Unlock() + c, ok := m.records[question] + if !ok || c != expected { + return 0 + } + c.lastFailedRefresh = time.Now() + c.consecFailures++ + return c.consecFailures +} + +// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let +// callers tell records, NODATA (nil err, no records), and failure apart. +func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) + return + } + + // TODO: drop once every supported OS registers a fallback resolver. Safe + // today: no root handler at priority ≤ PriorityUpstream means NetBird is + // not the system resolver, so net.DefaultResolver will not loop back. + aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) + return +} + +// lookupRecords resolves a single record type via chain or OS. The OS branch +// arms the loop detector for the duration of its call so that ServeDNS can +// spot the OS resolver routing the recursive query back to us. +func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) + } + + // TODO: drop once every supported OS registers a fallback resolver. + m.markRefreshing(q) + defer m.clearRefreshing(q) + + return m.osLookup(ctx, d, q.Name, q.Qtype) +} + +// lookupViaChain resolves via the handler chain and rewrites each RR to use +// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache +// target-owned records or upstream TTLs. NODATA returns (nil, nil). +func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { + msg := &dns.Msg{} + msg.SetQuestion(dnsName, qtype) + msg.RecursionDesired = true + + resp, err := chain.ResolveInternal(ctx, msg, maxPriority) + if err != nil { + return nil, fmt.Errorf("chain resolve: %w", err) + } + if resp == nil { + return nil, fmt.Errorf("chain resolve returned nil response") + } + if resp.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) + } + + ttl := uint32(m.cacheTTL.Seconds()) + owners := cnameOwners(dnsName, resp.Answer) + var filtered []dns.RR + for _, rr := range resp.Answer { + h := rr.Header() + if h.Class != dns.ClassINET || h.Rrtype != qtype { + continue + } + if !owners[strings.ToLower(dns.Fqdn(h.Name))] { + continue + } + if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { + filtered = append(filtered, cp) + } + } + return filtered, nil +} + +// osLookup resolves a single family via net.DefaultResolver using resutil, +// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA +// returns (nil, nil). +func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { + network := resutil.NetworkForQtype(qtype) + if network == "" { + return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) + } + + log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + + result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + if result.Rcode == dns.RcodeSuccess { + return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil + } + + if result.Err != nil { + return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) + } + return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) +} + +// responseTTL returns the remaining cache lifetime in seconds (rounded up), +// so downstream resolvers don't cache an answer for longer than we will. +func (m *Resolver) responseTTL(cachedAt time.Time) uint32 { + remaining := m.cacheTTL - time.Since(cachedAt) + if remaining <= 0 { + return 0 + } + return uint32((remaining + time.Second - 1) / time.Second) } // PopulateFromConfig extracts and caches domains from the client configuration. @@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - delete(m.records, aQuestion) - - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - delete(m.records, aaaaQuestion) + qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} + qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} + delete(m.records, qA) + delete(m.records, qAAAA) + delete(m.refreshing, qA) + delete(m.refreshing, qAAAA) log.Debugf("removed domain=%s from cache", d.SafeString()) return nil @@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } + +// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non +// A/AAAA records return nil. +func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + return &cp + case *dns.AAAA: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + return &cp + } + return nil +} + +// cloneRecordsWithTTL clones A/AAAA records preserving their owner and +// stamping ttl so the response shares no memory with the cached slice. +func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { + out := make([]dns.RR, 0, len(records)) + for _, rr := range records { + if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil { + out = append(out, cp) + } + } + return out +} + +// cnameOwners returns dnsName plus every target reachable by following CNAMEs +// in answer, iterating until fixed point so out-of-order chains resolve. +func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { + owners := map[string]bool{dnsName: true} + for { + added := false + for _, rr := range answer { + cname, ok := rr.(*dns.CNAME) + if !ok { + continue + } + name := strings.ToLower(dns.Fqdn(cname.Hdr.Name)) + if !owners[name] { + continue + } + target := strings.ToLower(dns.Fqdn(cname.Target)) + if !owners[target] { + owners[target] = true + added = true + } + } + if !added { + return owners + } + } +} + +// resolveCacheTTL reads the cache TTL override env var; invalid or empty +// values fall back to defaultTTL. Called once per Resolver from NewResolver. +func resolveCacheTTL() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return defaultTTL +} diff --git a/client/internal/dns/mgmt/mgmt_refresh_test.go b/client/internal/dns/mgmt/mgmt_refresh_test.go new file mode 100644 index 000000000..9faa5a0b8 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_refresh_test.go @@ -0,0 +1,408 @@ +package mgmt + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +type fakeChain struct { + mu sync.Mutex + calls map[string]int + answers map[string][]dns.RR + err error + hasRoot bool + onLookup func() +} + +func newFakeChain() *fakeChain { + return &fakeChain{ + calls: map[string]int{}, + answers: map[string][]dns.RR{}, + hasRoot: true, + } +} + +func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.hasRoot +} + +func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) { + f.mu.Lock() + q := msg.Question[0] + key := q.Name + "|" + dns.TypeToString[q.Qtype] + f.calls[key]++ + answers := f.answers[key] + err := f.err + onLookup := f.onLookup + f.mu.Unlock() + + if onLookup != nil { + onLookup() + } + if err != nil { + return nil, err + } + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Answer = answers + return resp, nil +} + +func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) { + f.mu.Lock() + defer f.mu.Unlock() + key := name + "|" + dns.TypeToString[qtype] + hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60} + switch qtype { + case dns.TypeA: + f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}} + case dns.TypeAAAA: + f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}} + } +} + +func (f *fakeChain) callCount(name string, qtype uint16) int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls[name+"|"+dns.TypeToString[qtype]] +} + +// waitFor polls the predicate until it returns true or the deadline passes. +func waitFor(t *testing.T, d time.Duration, fn func() bool) { + t.Helper() + deadline := time.Now().Add(d) + for time.Now().Before(deadline) { + if fn() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("condition not met within %s", d) +} + +func queryA(t *testing.T, r *Resolver, name string) *dns.Msg { + t.Helper() + msg := new(dns.Msg) + msg.SetQuestion(name, dns.TypeA) + w := &test.MockResponseWriter{} + r.ServeDNS(w, msg) + return w.GetLastResponse() +} + +func firstA(t *testing.T, resp *dns.Msg) string { + t.Helper() + require.NotNil(t, resp) + require.Greater(t, len(resp.Answer), 0, "expected at least one answer") + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok, "expected A record") + return a.A.String() +} + +func TestResolver_CacheTTLGatesRefresh(t *testing.T) { + // Same cached entry age, different cacheTTL values: the shorter TTL must + // trigger a background refresh, the longer one must not. Proves that the + // per-Resolver cacheTTL field actually drives the stale decision. + cachedAt := time.Now().Add(-100 * time.Millisecond) + + newRec := func() *cachedRecord { + return &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: cachedAt, + } + } + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + + t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = 10 * time.Millisecond + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount(q.Name, dns.TypeA) >= 1 + }) + }) + + t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = time.Hour + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh") + }) +} + +func TestResolver_ServeFresh_NoRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), // fresh + } + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(20 * time.Millisecond) + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh") +} + +func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), // stale + } + + // First query: serves stale immediately. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1 + }) + + // Next query should now return the refreshed IP. + waitFor(t, time.Second, func() bool { + resp := queryA(t, r, "mgmt.example.com.") + return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2" + }) +} + +func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + + var inflight atomic.Int32 + var maxInflight atomic.Int32 + chain.onLookup = func() { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + prev := maxInflight.Load() + if cur <= prev || maxInflight.CompareAndSwap(prev, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide + } + + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + queryA(t, r, "mgmt.example.com.") + }() + } + wg.Wait() + + waitFor(t, 2*time.Second, func() bool { + return inflight.Load() == 0 + }) + + calls := chain.callCount("mgmt.example.com.", dns.TypeA) + assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls) + assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently") +} + +func TestResolver_RefreshFailureArmsBackoff(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.err = errors.New("boom") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + // First stale hit triggers a refresh attempt that fails. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) == 1 + }) + waitFor(t, time.Second, func() bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + c, ok := r.records[q] + return ok && !c.lastFailedRefresh.IsZero() + }) + + // Subsequent stale hits within backoff window should not schedule more refreshes. + for i := 0; i < 10; i++ { + queryA(t, r, "mgmt.example.com.") + } + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes") +} + +func TestResolver_NoRootHandler_SkipsChain(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.hasRoot = false + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + // With hasRoot=false the chain must not be consulted. Use a short + // deadline so the OS fallback returns quickly without waiting on a + // real network call in CI. + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + _, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.") + + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), + "chain must not be used when no root handler is registered at the bound priority") +} + +func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) { + // ServeDNS being invoked for a question while a refresh for that question + // is inflight indicates a resolver loop (OS resolver sent the recursive + // query back to us). The inflightRefresh.loopLoggedOnce flag must be set. + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + // Simulate an inflight refresh. + r.markRefreshing(q) + defer r.clearRefreshing(q) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries") + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + require.NotNil(t, inflight) + assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed") +} + +func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + r.markRefreshing(q) + defer r.clearRefreshing(q) + + // Multiple ServeDNS calls during the same refresh must not re-set the flag + // (CompareAndSwap from false -> true returns true only on the first call). + for range 5 { + queryA(t, r, "mgmt.example.com.") + } + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + assert.True(t, inflight.Load()) +} + +func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + queryA(t, r, "mgmt.example.com.") + + r.mutex.RLock() + _, ok := r.refreshing[q] + r.mutex.RUnlock() + assert.False(t, ok, "no refresh inflight means no loop tracking") +} + +func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2") + r.SetChainResolver(chain, 50) + + require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com"))) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.2", firstA(t, resp)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA)) +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 9e8a746f3..276cbba0a 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -6,6 +6,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) { assert.False(t, resolver.MatchSubdomains()) } +func TestResolveCacheTTL(t *testing.T) { + tests := []struct { + name string + value string + want time.Duration + }{ + {"unset falls back to default", "", defaultTTL}, + {"valid duration", "45s", 45 * time.Second}, + {"valid minutes", "2m", 2 * time.Minute}, + {"malformed falls back to default", "not-a-duration", defaultTTL}, + {"zero falls back to default", "0s", defaultTTL}, + {"negative falls back to default", "-5s", defaultTTL}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envMgmtCacheTTL, tc.value) + got := resolveCacheTTL() + assert.Equal(t, tc.want, got, "parsed TTL should match") + }) + } +} + +func TestNewResolver_CacheTTLFromEnv(t *testing.T) { + t.Setenv(envMgmtCacheTTL, "7s") + r := NewResolver() + assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env") +} + +func TestResolver_ResponseTTL(t *testing.T) { + now := time.Now() + tests := []struct { + name string + cacheTTL time.Duration + cachedAt time.Time + wantMin uint32 + wantMax uint32 + }{ + {"fresh entry returns full TTL", 60 * time.Second, now, 59, 60}, + {"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31}, + {"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0}, + {"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := &Resolver{cacheTTL: tc.cacheTTL} + got := r.responseTTL(tc.cachedAt) + assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin") + assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax") + }) + } +} + func TestResolver_ExtractDomainFromURL(t *testing.T) { tests := []struct { name string diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 1df57d1db..548b1f54f 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { // Mock implementation - no-op } +// SetFirewall mock implementation of SetFirewall from Server interface +func (m *MockServer) SetFirewall(Firewall) { + // Mock implementation - no-op +} + // BeginBatch mock implementation of BeginBatch from Server interface func (m *MockServer) BeginBatch() { // Mock implementation - no-op diff --git a/client/internal/dns/response_writer.go b/client/internal/dns/response_writer.go index edc65a5d9..287cf28b0 100644 --- a/client/internal/dns/response_writer.go +++ b/client/internal/dns/response_writer.go @@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) { // After a call to Hijack(), the DNS package will not do anything with the connection. func (r *responseWriter) Hijack() { } + +// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging. +func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr { + var srcIP net.IP + if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil { + srcIP = ipv4.(*layers.IPv4).SrcIP + } else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil { + srcIP = ipv6.(*layers.IPv6).SrcIP + } + + var srcPort int + if udp := packet.Layer(layers.LayerTypeUDP); udp != nil { + srcPort = int(udp.(*layers.UDP).SrcPort) + } + + if srcIP == nil { + return nil + } + return &net.UDPAddr{IP: srcIP, Port: srcPort} +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 3c47f4ee6..d4f54dec5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -58,6 +58,7 @@ type Server interface { UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error SetRouteChecker(func(netip.Addr) bool) + SetFirewall(Firewall) } type nsGroupsByDomain struct { @@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default if config.WgInterface.IsUserspaceBind() { dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(config.WgInterface, addrPort) + dnsService = newServiceViaListener(config.WgInterface, addrPort, nil) } server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) @@ -186,11 +187,16 @@ func NewDefaultServerIos( ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager, + hostsDnsList []netip.AddrPort, statusRecorder *peer.Status, disableSys bool, ) *DefaultServer { + log.Debugf("iOS host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager + ds.hostsDNSHolder.set(hostsDnsList) + ds.permanent = true + ds.addHostRootZone() return ds } @@ -206,6 +212,7 @@ func newDefaultServer( ctx, stop := context.WithCancel(ctx) mgmtCacheResolver := mgmt.NewResolver() + mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream) defaultServer := &DefaultServer{ ctx: ctx, @@ -374,6 +381,17 @@ func (s *DefaultServer) DnsIP() netip.Addr { return s.service.RuntimeIP() } +// SetFirewall sets the firewall used for DNS port DNAT rules. +// This must be called before Initialize when using the listener-based service, +// because the firewall is typically not available at construction time. +func (s *DefaultServer) SetFirewall(fw Firewall) { + if svc, ok := s.service.(*serviceViaListener); ok { + svc.listenerFlagLock.Lock() + svc.firewall = fw + svc.listenerFlagLock.Unlock() + } +} + // Stop stops the server func (s *DefaultServer) Stop() { s.probeMu.Lock() @@ -395,8 +413,12 @@ func (s *DefaultServer) Stop() { maps.Clear(s.extraDomains) } -func (s *DefaultServer) disableDNS() error { - defer s.service.Stop() +func (s *DefaultServer) disableDNS() (retErr error) { + defer func() { + if err := s.service.Stop(); err != nil { + retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err)) + } + }() if s.isUsingNoopHostManager() { return nil diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index d3b0c250d..f77f6e898 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes() - packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - packetfilter.EXPECT().RemovePacketHook(gomock.Any()) + packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() if err := wgIface.SetFilter(packetfilter); err != nil { t.Errorf("set packet filter: %v", err) @@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand type mockService struct{} func (m *mockService) Listen() error { return nil } -func (m *mockService) Stop() {} +func (m *mockService) Stop() error { return nil } func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") } func (m *mockService) RuntimePort() int { return 53 } func (m *mockService) RegisterMux(string, dns.Handler) {} diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index 6a76c53e3..1c6ce7849 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -4,15 +4,25 @@ import ( "net/netip" "github.com/miekg/dns" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" ) const ( DefaultPort = 53 ) +// Firewall provides DNAT capabilities for DNS port redirection. +// This is used when the DNS server cannot bind port 53 directly +// and needs firewall rules to redirect traffic. +type Firewall interface { + AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error + RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error +} + type service interface { Listen() error - Stop() + Stop() error RegisterMux(domain string, handler dns.Handler) DeregisterMux(key string) RuntimePort() int diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index f7ddfd40f..4e09f1b7f 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -10,9 +10,13 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ) @@ -31,25 +35,33 @@ type serviceViaListener struct { dnsMux *dns.ServeMux customAddr *netip.AddrPort server *dns.Server + tcpServer *dns.Server listenIP netip.Addr listenPort uint16 listenerIsRunning bool listenerFlagLock sync.Mutex ebpfService ebpfMgr.Manager + firewall Firewall + tcpDNATConfigured bool } -func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener { +func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener { mux := dns.NewServeMux() s := &serviceViaListener{ wgInterface: wgIface, dnsMux: mux, customAddr: customAddr, + firewall: fw, server: &dns.Server{ Net: "udp", Handler: mux, UDPSize: 65535, }, + tcpServer: &dns.Server{ + Net: "tcp", + Handler: mux, + }, } return s @@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error { return fmt.Errorf("eval listen address: %w", err) } s.listenIP = s.listenIP.Unmap() - s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort))) - log.Debugf("starting dns on %s", s.server.Addr) - go func() { - s.setListenerStatus(true) - defer s.setListenerStatus(false) + addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort))) + s.server.Addr = addr + s.tcpServer.Addr = addr - err := s.server.ListenAndServe() - if err != nil { - log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err) + log.Debugf("starting dns on %s (UDP + TCP)", addr) + s.listenerIsRunning = true + + go func() { + if err := s.server.ListenAndServe(); err != nil { + log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err) + } + + s.listenerFlagLock.Lock() + unexpected := s.listenerIsRunning + s.listenerIsRunning = false + s.listenerFlagLock.Unlock() + + if unexpected { + if err := s.tcpServer.Shutdown(); err != nil { + log.Debugf("failed to shutdown DNS TCP server: %v", err) + } } }() + go func() { + if err := s.tcpServer.ListenAndServe(); err != nil { + log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err) + } + }() + + // When eBPF redirects UDP port 53 to our listen port, TCP still needs + // a DNAT rule because eBPF only handles UDP. + if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort { + if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err) + } else { + s.tcpDNATConfigured = true + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort) + } + } + return nil } -func (s *serviceViaListener) Stop() { +func (s *serviceViaListener) Stop() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() if !s.listenerIsRunning { - return + return nil } + s.listenerIsRunning = false ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err := s.server.ShutdownContext(ctx) - if err != nil { - log.Errorf("stopping dns server listener returned an error: %v", err) + var merr *multierror.Error + + if err := s.server.ShutdownContext(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err)) + } + + if err := s.tcpServer.ShutdownContext(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err)) + } + + if s.tcpDNATConfigured && s.firewall != nil { + if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + s.tcpDNATConfigured = false } if s.ebpfService != nil { - err = s.ebpfService.FreeDNSFwd() - if err != nil { - log.Errorf("stopping traffic forwarder returned an error: %v", err) + if err := s.ebpfService.FreeDNSFwd(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err)) } } + + return nberrors.FormatErrorOrNil(merr) } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { @@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { return s.listenIP } -func (s *serviceViaListener) setListenerStatus(running bool) { - s.listenerFlagLock.Lock() - defer s.listenerFlagLock.Unlock() - - s.listenerIsRunning = running -} // evalListenAddress figure out the listen address for the DNS server // first check the 53 port availability on WG interface or lo, if not success @@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) { } func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool { - addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port)) - udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) - probeListener, err := net.ListenUDP("udp", udpAddr) + addrPort := netip.AddrPortFrom(ip, uint16(port)) + + udpAddr := net.UDPAddrFromAddrPort(addrPort) + udpLn, err := net.ListenUDP("udp", udpAddr) if err != nil { - log.Warnf("binding dns on %s is not available, error: %s", addrString, err) + log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err) return false } - - err = probeListener.Close() - if err != nil { - log.Errorf("got an error closing the probe listener, error: %s", err) + if err := udpLn.Close(); err != nil { + log.Debugf("close UDP probe listener: %s", err) } + + tcpAddr := net.TCPAddrFromAddrPort(addrPort) + tcpLn, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err) + return false + } + if err := tcpLn.Close(); err != nil { + log.Debugf("close TCP probe listener: %s", err) + } + return true } diff --git a/client/internal/dns/service_listener_test.go b/client/internal/dns/service_listener_test.go new file mode 100644 index 000000000..90ef71d19 --- /dev/null +++ b/client/internal/dns/service_listener_test.go @@ -0,0 +1,86 @@ +package dns + +import ( + "fmt" + "net" + "net/netip" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServiceViaListener_TCPAndUDP(t *testing.T) { + handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("192.0.2.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + // Create a service using a custom address to avoid needing root + svc := newServiceViaListener(nil, nil, nil) + svc.dnsMux.Handle(".", handler) + + // Bind both transports up front to avoid TOCTOU races. + udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0)) + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + t.Skip("cannot bind to 127.0.0.153, skipping") + } + port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port) + + tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port)) + tcpLn, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + udpConn.Close() + t.Skip("cannot bind TCP on same port, skipping") + } + + addr := fmt.Sprintf("%s:%d", customIP, port) + svc.server.PacketConn = udpConn + svc.tcpServer.Listener = tcpLn + svc.listenIP = customIP + svc.listenPort = port + + go func() { + if err := svc.server.ActivateAndServe(); err != nil { + t.Logf("udp server: %v", err) + } + }() + go func() { + if err := svc.tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + svc.listenerIsRunning = true + + defer func() { + require.NoError(t, svc.Stop()) + }() + + q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + // Test UDP query + udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second} + udpResp, _, err := udpClient.Exchange(q, addr) + require.NoError(t, err, "UDP query should succeed") + require.NotNil(t, udpResp) + require.NotEmpty(t, udpResp.Answer) + assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP") + + // Test TCP query + tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second} + tcpResp, _, err := tcpClient.Exchange(q, addr) + require.NoError(t, err, "TCP query should succeed") + require.NotNil(t, tcpResp) + require.NotEmpty(t, tcpResp.Answer) + assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP") +} diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 6ef0ab526..e8c036076 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -1,6 +1,7 @@ package dns import ( + "errors" "fmt" "net/netip" "sync" @@ -10,6 +11,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -18,7 +20,8 @@ type ServiceViaMemory struct { dnsMux *dns.ServeMux runtimeIP netip.Addr runtimePort int - udpFilterHookID string + tcpDNS *tcpDNSServer + tcpHookSet bool listenerIsRunning bool listenerFlagLock sync.Mutex } @@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { if err != nil { log.Errorf("get last ip from network: %v", err) } - s := &ServiceViaMemory{ + + return &ServiceViaMemory{ wgInterface: wgIface, dnsMux: dns.NewServeMux(), - runtimeIP: lastIP, runtimePort: DefaultPort, } - return s } func (s *ServiceViaMemory) Listen() error { @@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error { return nil } - var err error - s.udpFilterHookID, err = s.filterDNSTraffic() - if err != nil { - return fmt.Errorf("filter dns traffice: %w", err) + if err := s.filterDNSTraffic(); err != nil { + return fmt.Errorf("filter dns traffic: %w", err) } s.listenerIsRunning = true @@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error { return nil } -func (s *ServiceViaMemory) Stop() { +func (s *ServiceViaMemory) Stop() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() if !s.listenerIsRunning { - return + return nil } - if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil { - log.Errorf("unable to remove DNS packet hook: %s", err) + filter := s.wgInterface.GetFilter() + if filter != nil { + filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil) + if s.tcpHookSet { + filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil) + } + } + + if s.tcpDNS != nil { + s.tcpDNS.Stop() } s.listenerIsRunning = false + + return nil } func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { @@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr { return s.runtimeIP } -func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { +func (s *ServiceViaMemory) filterDNSTraffic() error { filter := s.wgInterface.GetFilter() if filter == nil { - return "", fmt.Errorf("can't set DNS filter, filter not initialized") + return errors.New("DNS filter not initialized") + } + + // Create TCP DNS server lazily here since the device may not exist at construction time. + if s.tcpDNS == nil { + if dev := s.wgInterface.GetDevice(); dev != nil { + // MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact. + s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU) + } } firstLayerDecoder := layers.LayerTypeIPv4 @@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { } hook := func(packetData []byte) bool { - // Decode the packet packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) - // Get the UDP layer udpLayer := packet.Layer(layers.LayerTypeUDP) - udp := udpLayer.(*layers.UDP) + if udpLayer == nil { + return true + } + udp, ok := udpLayer.(*layers.UDP) + if !ok { + return true + } msg := new(dns.Msg) if err := msg.Unpack(udp.Payload); err != nil { @@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { return true } - writer := responseWriter{ - packet: packet, - device: s.wgInterface.GetDevice().Device, + dev := s.wgInterface.GetDevice() + if dev == nil { + return true } - go s.dnsMux.ServeDNS(&writer, msg) + + writer := &responseWriter{ + remote: remoteAddrFromPacket(packet), + packet: packet, + device: dev.Device, + } + go s.dnsMux.ServeDNS(writer, msg) return true } - return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil + filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook) + + if s.tcpDNS != nil { + tcpHook := func(packetData []byte) bool { + s.tcpDNS.InjectPacket(packetData) + return true + } + filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook) + s.tcpHookSet = true + } + + return nil } diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go new file mode 100644 index 000000000..88e72e767 --- /dev/null +++ b/client/internal/dns/tcpstack.go @@ -0,0 +1,444 @@ +package dns + +import ( + "errors" + "fmt" + "io" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + dnsTCPReceiveWindow = 8192 + dnsTCPMaxInFlight = 16 + dnsTCPIdleTimeout = 30 * time.Second + dnsTCPReadTimeout = 5 * time.Second +) + +// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack. +// It is started lazily when a truncated DNS response is detected and shuts down +// after a period of inactivity to conserve resources. +type tcpDNSServer struct { + mu sync.Mutex + s *stack.Stack + ep *dnsEndpoint + mux *dns.ServeMux + tunDev tun.Device + ip netip.Addr + port uint16 + mtu uint16 + + running bool + closed bool + timerID uint64 + timer *time.Timer +} + +func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer { + return &tcpDNSServer{ + mux: mux, + tunDev: tunDev, + ip: ip, + port: port, + mtu: mtu, + } +} + +// InjectPacket ensures the stack is running and delivers a raw IP packet into +// the gvisor stack for TCP processing. Combining both operations under a single +// lock prevents a race where the idle timer could stop the stack between +// start and delivery. +func (t *tcpDNSServer) InjectPacket(payload []byte) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.closed { + return + } + + if !t.running { + if err := t.startLocked(); err != nil { + log.Errorf("failed to start TCP DNS stack: %v", err) + return + } + t.running = true + log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload)) + } + t.resetTimerLocked() + + ep := t.ep + if ep == nil || ep.dispatcher == nil { + return + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + // DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef. + ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) +} + +// Stop tears down the gvisor stack and releases resources permanently. +// After Stop, InjectPacket becomes a no-op. +func (t *tcpDNSServer) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopLocked() + t.closed = true +} + +func (t *tcpDNSServer) startLocked() error { + // TODO: add ipv6.NewProtocol when IPv6 overlay support lands. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: false, + }) + + nicID := tcpip.NICID(1) + ep := &dnsEndpoint{ + tunDev: t.tunDev, + } + ep.mtu.Store(uint32(t.mtu)) + + if err := s.CreateNIC(nicID, ep); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("create NIC: %v", err) + } + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(t.ip.AsSlice()), + PrefixLen: 32, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("add protocol address: %s", err) + } + + if err := s.SetPromiscuousMode(nicID, true); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("set promiscuous mode: %s", err) + } + if err := s.SetSpoofing(nicID, true); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("set spoofing: %s", err) + } + + defaultSubnet, err := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), + ) + if err != nil { + s.Close() + s.Wait() + return fmt.Errorf("create default subnet: %w", err) + } + + s.SetRouteTable([]tcpip.Route{ + {Destination: defaultSubnet, NIC: nicID}, + }) + + tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) { + t.handleTCPDNS(r) + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + t.s = s + t.ep = ep + return nil +} + +func (t *tcpDNSServer) stopLocked() { + if !t.running { + return + } + + if t.timer != nil { + t.timer.Stop() + t.timer = nil + } + + if t.s != nil { + t.s.Close() + t.s.Wait() + t.s = nil + } + t.ep = nil + t.running = false + + log.Debugf("TCP DNS stack stopped") +} + +func (t *tcpDNSServer) resetTimerLocked() { + if t.timer != nil { + t.timer.Stop() + } + t.timerID++ + id := t.timerID + t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() { + t.mu.Lock() + defer t.mu.Unlock() + + // Only stop if this timer is still the active one. + // A racing InjectPacket may have replaced it. + if t.timerID != id { + return + } + t.stopLocked() + }) +} + +func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) { + id := r.ID() + + wq := waiter.Queue{} + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + log.Debugf("TCP DNS: failed to create endpoint: %v", epErr) + r.Complete(true) + return + } + r.Complete(false) + + conn := gonet.NewTCPConn(&wq, ep) + defer func() { + if err := conn.Close(); err != nil { + log.Tracef("TCP DNS: close conn: %v", err) + } + }() + + // Reset idle timer on activity + t.mu.Lock() + t.resetTimerLocked() + t.mu.Unlock() + + localAddr := &net.TCPAddr{ + IP: id.LocalAddress.AsSlice(), + Port: int(id.LocalPort), + } + remoteAddr := &net.TCPAddr{ + IP: id.RemoteAddress.AsSlice(), + Port: int(id.RemotePort), + } + + for { + if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil { + log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err) + break + } + + msg, err := readTCPDNSMessage(conn) + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err) + } + break + } + + writer := &tcpResponseWriter{ + conn: conn, + localAddr: localAddr, + remoteAddr: remoteAddr, + } + t.mux.ServeDNS(writer, msg) + } +} + +// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device. +type dnsEndpoint struct { + dispatcher stack.NetworkDispatcher + tunDev tun.Device + mtu atomic.Uint32 +} + +func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher } +func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil } +func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() } +func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone } +func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 } +func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" } +func (e *dnsEndpoint) Wait() { /* no async work */ } +func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } +func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ } +func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true } +func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ } +func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ } +func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) } +func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ } + +const tunPacketOffset = 40 + +func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + var written int + for _, pkt := range pkts.AsSlice() { + data := stack.PayloadSince(pkt.NetworkHeader()) + if data == nil { + continue + } + + raw := data.AsSlice() + buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw)) + buf = append(buf, raw...) + data.Release() + + if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil { + log.Tracef("TCP DNS endpoint: failed to write packet: %v", err) + continue + } + written++ + } + return written, nil +} + +// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections. +type tcpResponseWriter struct { + conn *gonet.TCPConn + localAddr net.Addr + remoteAddr net.Addr +} + +func (w *tcpResponseWriter) LocalAddr() net.Addr { + return w.localAddr +} + +func (w *tcpResponseWriter) RemoteAddr() net.Addr { + return w.remoteAddr +} + +func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error { + data, err := msg.Pack() + if err != nil { + return fmt.Errorf("pack: %w", err) + } + + // DNS TCP: 2-byte length prefix + message + buf := make([]byte, 2+len(data)) + buf[0] = byte(len(data) >> 8) + buf[1] = byte(len(data)) + copy(buf[2:], data) + + if _, err = w.conn.Write(buf); err != nil { + return err + } + return nil +} + +func (w *tcpResponseWriter) Write(data []byte) (int, error) { + buf := make([]byte, 2+len(data)) + buf[0] = byte(len(data) >> 8) + buf[1] = byte(len(data)) + copy(buf[2:], data) + if _, err := w.conn.Write(buf); err != nil { + return 0, err + } + return len(data), nil +} + +func (w *tcpResponseWriter) Close() error { + return w.conn.Close() +} + +func (w *tcpResponseWriter) TsigStatus() error { return nil } +func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ } +func (w *tcpResponseWriter) Hijack() { /* not supported */ } + +// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed). +func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) { + // DNS over TCP uses a 2-byte length prefix + lenBuf := make([]byte, 2) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return nil, fmt.Errorf("read length: %w", err) + } + + msgLen := int(lenBuf[0])<<8 | int(lenBuf[1]) + if msgLen == 0 || msgLen > 65535 { + return nil, fmt.Errorf("invalid message length: %d", msgLen) + } + + msgBuf := make([]byte, msgLen) + if _, err := io.ReadFull(conn, msgBuf); err != nil { + return nil, fmt.Errorf("read message: %w", err) + } + + msg := new(dns.Msg) + if err := msg.Unpack(msgBuf); err != nil { + return nil, fmt.Errorf("unpack: %w", err) + } + return msg, nil +} + +// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging. +// Supports both IPv4 and IPv6. +func srcAddrFromPacket(pkt []byte) netip.AddrPort { + if len(pkt) == 0 { + return netip.AddrPort{} + } + + srcIP, transportOffset := srcIPFromPacket(pkt) + if !srcIP.IsValid() || len(pkt) < transportOffset+2 { + return netip.AddrPort{} + } + + srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1]) + return netip.AddrPortFrom(srcIP.Unmap(), srcPort) +} + +func srcIPFromPacket(pkt []byte) (netip.Addr, int) { + switch header.IPVersion(pkt) { + case 4: + return srcIPv4(pkt) + case 6: + return srcIPv6(pkt) + default: + return netip.Addr{}, 0 + } +} + +func srcIPv4(pkt []byte) (netip.Addr, int) { + if len(pkt) < header.IPv4MinimumSize { + return netip.Addr{}, 0 + } + hdr := header.IPv4(pkt) + src := hdr.SourceAddress() + ip, ok := netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.Addr{}, 0 + } + return ip, int(hdr.HeaderLength()) +} + +func srcIPv6(pkt []byte) (netip.Addr, int) { + if len(pkt) < header.IPv6MinimumSize { + return netip.Addr{}, 0 + } + hdr := header.IPv6(pkt) + src := hdr.SourceAddress() + ip, ok := netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.Addr{}, 0 + } + return ip, header.IPv6MinimumSize +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 5b8135132..746b73ca7 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -41,10 +41,61 @@ const ( reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second + + // ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP + // payload from the tunnel MTU. + ipUDPHeaderSize = 60 + 8 ) const testRecord = "com." +const ( + protoUDP = "udp" + protoTCP = "tcp" +) + +type dnsProtocolKey struct{} + +// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context. +func contextWithDNSProtocol(ctx context.Context, network string) context.Context { + return context.WithValue(ctx, dnsProtocolKey{}, network) +} + +// dnsProtocolFromContext retrieves the inbound DNS protocol from context. +func dnsProtocolFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok { + return v + } + return "" +} + +type upstreamProtocolKey struct{} + +// upstreamProtocolResult holds the protocol used for the upstream exchange. +// Stored as a pointer in context so the exchange function can set it. +type upstreamProtocolResult struct { + protocol string +} + +// contextWithupstreamProtocolResult stores a mutable result holder in the context. +func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { + r := &upstreamProtocolResult{} + return context.WithValue(ctx, upstreamProtocolKey{}, r), r +} + +// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present. +func setUpstreamProtocol(ctx context.Context, protocol string) { + if ctx == nil { + return + } + if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil { + r.protocol = protocol + } +} + type upstreamClient interface { exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) } @@ -138,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - ok, failures := u.tryUpstreamServers(w, r, logger) + // Propagate inbound protocol so upstream exchange can use TCP directly + // when the request came in over TCP. + ctx := u.ctx + if addr := w.RemoteAddr(); addr != nil { + network := addr.Network() + ctx = contextWithDNSProtocol(ctx, network) + resutil.SetMeta(w, "protocol", network) + } + + ok, failures := u.tryUpstreamServers(ctx, w, r, logger) if len(failures) > 0 { u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger) } @@ -153,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } } -func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { +func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { timeout := u.upstreamTimeout if len(u.upstreamServers) > 1 { maxTotal := 5 * time.Second @@ -168,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M var failures []upstreamFailure for _, upstream := range u.upstreamServers { - if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil { + if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil { failures = append(failures, *failure) } else { return true, failures @@ -178,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M } // queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. -func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { var rm *dns.Msg var t time.Duration var err error var startTime time.Time + var upstreamProto *upstreamProtocolResult func() { - ctx, cancel := context.WithTimeout(u.ctx, timeout) + ctx, cancel := context.WithTimeout(parentCtx, timeout) defer cancel() + ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) startTime = time.Now() rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) }() @@ -203,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) + u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) return nil } @@ -220,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add return &upstreamFailure{upstream: upstream, reason: reason} } -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool { u.successCount.Add(1) resutil.SetMeta(w, "upstream", upstream.String()) + if upstreamProto != nil && upstreamProto.protocol != "" { + resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol) + } // Clear Zero bit from external responses to prevent upstream servers from // manipulating our internal fallthrough signaling mechanism @@ -428,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC return err } +// clientUDPMaxSize returns the maximum UDP response size the client accepts. +func clientUDPMaxSize(r *dns.Msg) int { + if opt := r.IsEdns0(); opt != nil { + return int(opt.UDPSize()) + } + return dns.MinMsgSize +} + // ExchangeWithFallback exchanges a DNS message with the upstream server. // It first tries to use UDP, and if it is truncated, it falls back to TCP. +// If the inbound request came over TCP (via context), it skips the UDP attempt. // If the passed context is nil, this will use Exchange instead of ExchangeContext. func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { - // MTU - ip + udp headers - // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling. - client.UDPSize = uint16(currentMTU - (60 + 8)) + // If the request came in over TCP, go straight to TCP upstream. + if dnsProtocolFromContext(ctx) == protoTCP { + tcpClient := *client + tcpClient.Net = protoTCP + rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream) + if err != nil { + return nil, t, fmt.Errorf("with tcp: %w", err) + } + setUpstreamProtocol(ctx, protoTCP) + return rm, t, nil + } + + clientMaxSize := clientUDPMaxSize(r) + + // Cap EDNS0 to our tunnel MTU so the upstream doesn't send a + // response larger than our read buffer. + // Note: the query could be sent out on an interface that is not ours, + // but higher MTU settings could break truncation handling. + maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize) + client.UDPSize = maxUDPPayload + if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload { + opt.SetUDPSize(maxUDPPayload) + } var ( rm *dns.Msg @@ -453,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u } if rm == nil || !rm.MsgHdr.Truncated { + setUpstreamProtocol(ctx, protoUDP) return rm, t, nil } - log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.", - r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + // TODO: if the upstream's truncated UDP response already contains more + // data than the client's buffer, we could truncate locally and skip + // the TCP retry. - client.Net = "tcp" + tcpClient := *client + tcpClient.Net = protoTCP if ctx == nil { - rm, t, err = client.Exchange(r, upstream) + rm, t, err = tcpClient.Exchange(r, upstream) } else { - rm, t, err = client.ExchangeContext(ctx, r, upstream) + rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream) } if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } - // TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP + setUpstreamProtocol(ctx, protoTCP) + + if rm.Len() > clientMaxSize { + rm.Truncate(clientMaxSize) + } return rm, t, nil } @@ -479,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u // ExchangeWithNetstack performs a DNS exchange using netstack for dialing. // This is needed when netstack is enabled to reach peer IPs through the tunnel. func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) { - reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp") + // If request came in over TCP, go straight to TCP upstream + if dnsProtocolFromContext(ctx) == protoTCP { + rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP) + if err != nil { + return nil, err + } + setUpstreamProtocol(ctx, protoTCP) + return rm, nil + } + + clientMaxSize := clientUDPMaxSize(r) + + // Cap EDNS0 to our tunnel MTU so the upstream doesn't send a + // response larger than what we can read over UDP. + maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize) + if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload { + opt.SetUDPSize(maxUDPPayload) + } + + reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP) if err != nil { return nil, err } - // If response is truncated, retry with TCP if reply != nil && reply.MsgHdr.Truncated { - log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP", - r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - return netstackExchange(ctx, nsNet, r, upstream, "tcp") + rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP) + if err != nil { + return nil, err + } + + setUpstreamProtocol(ctx, protoTCP) + if rm.Len() > clientMaxSize { + rm.Truncate(clientMaxSize) + } + + return rm, nil } + setUpstreamProtocol(ctx, protoUDP) + return reply, nil } @@ -511,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst } } - dnsConn := &dns.Conn{Conn: conn} + dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)} if err := dnsConn.WriteMsg(r); err != nil { return nil, fmt.Errorf("write %s message: %w", network, err) diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index d7cff377b..ee1ca42fe 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin upstreamExchangeClient := &dns.Client{ Timeout: ClientTimeout, } - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream) } // exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN @@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri Timeout: timeout, } - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index ab164c30b..1797fdad8 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) { }) } } + +func TestDNSProtocolContext(t *testing.T) { + t.Run("roundtrip udp", func(t *testing.T) { + ctx := contextWithDNSProtocol(context.Background(), protoUDP) + assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx)) + }) + + t.Run("roundtrip tcp", func(t *testing.T) { + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx)) + }) + + t.Run("missing returns empty", func(t *testing.T) { + assert.Equal(t, "", dnsProtocolFromContext(context.Background())) + }) +} + +func TestExchangeWithFallback_TCPContext(t *testing.T) { + // Start a local DNS server that responds on TCP only + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpServer := &dns.Server{ + Addr: "127.0.0.1:0", + Net: "tcp", + Handler: tcpHandler, + } + + tcpLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tcpServer.Listener = tcpLn + + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = tcpServer.Shutdown() + }() + + upstream := tcpLn.Addr().String() + + // With TCP context, should connect directly via TCP without trying UDP + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, upstream) + require.NoError(t, err) + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer) + assert.Contains(t, rm.Answer[0].String(), "10.0.0.1") +} + +func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) { + // UDP handler returns a truncated response to trigger TCP retry. + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = true + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + // TCP handler returns the full answer. + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.3"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{ + PacketConn: udpPC, + Net: "udp", + Handler: udpHandler, + } + + tcpLn, err := net.Listen("tcp", addr) + require.NoError(t, err) + + tcpServer := &dns.Server{ + Listener: tcpLn, + Net: "tcp", + Handler: tcpHandler, + } + + go func() { + if err := udpServer.ActivateAndServe(); err != nil { + t.Logf("udp server: %v", err) + } + }() + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = udpServer.Shutdown() + _ = tcpServer.Shutdown() + }() + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err, "should fall back to TCP after truncated UDP response") + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer") + assert.Contains(t, rm.Answer[0].String(), "10.0.0.3") + assert.False(t, rm.Truncated, "TCP response should not be truncated") +} + +func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) { + // Start only a TCP server (no UDP). With TCP context it should succeed. + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.2"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + tcpServer := &dns.Server{ + Listener: tcpLn, + Net: "tcp", + Handler: tcpHandler, + } + + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = tcpServer.Shutdown() + }() + + upstream := tcpLn.Addr().String() + + // TCP context: should skip UDP entirely and go directly to TCP + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, upstream) + require.NoError(t, err) + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer) + assert.Contains(t, rm.Answer[0].String(), "10.0.0.2") + + // Without TCP context, trying to reach a TCP-only server via UDP should fail + ctx2 := context.Background() + client2 := &dns.Client{Timeout: 500 * time.Millisecond} + _, _, err = ExchangeWithFallback(ctx2, client2, r, upstream) + assert.Error(t, err, "should fail when no UDP server and no TCP context") +} + +func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { + // Verify that a client EDNS0 larger than our MTU-derived limit gets + // capped in the outgoing request so the upstream doesn't send a + // response larger than our read buffer. + var receivedUDPSize uint16 + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + if opt := r.IsEdns0(); opt != nil { + receivedUDPSize = opt.UDPSize() + } + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler} + go func() { _ = udpServer.ActivateAndServe() }() + t.Cleanup(func() { _ = udpServer.Shutdown() }) + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + r.SetEdns0(4096, false) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err) + require.NotNil(t, rm) + + expectedMax := uint16(currentMTU - ipUDPHeaderSize) + assert.Equal(t, expectedMax, receivedUDPSize, + "upstream should see capped EDNS0, not the client's 4096") +} + +func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) { + // When the client advertises a large EDNS0 (4096) and the upstream + // truncates, the TCP response should NOT be truncated since the full + // answer fits within the client's original buffer. + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = true + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + // Add enough records to exceed MTU but fit within 4096 + for i := range 20 { + m.Answer = append(m.Answer, &dns.TXT{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60}, + Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)}, + }) + } + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler} + tcpLn, err := net.Listen("tcp", addr) + require.NoError(t, err) + tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler} + + go func() { _ = udpServer.ActivateAndServe() }() + go func() { _ = tcpServer.ActivateAndServe() }() + t.Cleanup(func() { + _ = udpServer.Shutdown() + _ = tcpServer.Shutdown() + }) + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + + // Client with large buffer: should get all records without truncation + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + r.SetEdns0(4096, false) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err) + require.NotNil(t, rm) + assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records") + assert.False(t, rm.Truncated, "response should not be truncated for large buffer client") + + // Client with small buffer: should get truncated response + r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + r2.SetEdns0(512, false) + + rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr) + require.NoError(t, err) + require.NotNil(t, rm2) + assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records") + assert.True(t, rm2.Truncated, "response should be truncated for small buffer client") +} diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 5c7cb31fc..2e8ef84ab 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re return } - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s", + qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime)) } // udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation. @@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { startTime := time.Now() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": resutil.GenerateRequestID(), "dns_id": fmt.Sprintf("%04x", query.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { startTime := time.Now() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": resutil.GenerateRequestID(), "dns_id": fmt.Sprintf("%04x", query.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) f.handleDNSQuery(logger, w, query, startTime) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 7b100bd0c..7f19e2d28 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -26,7 +26,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" @@ -46,6 +48,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" @@ -66,6 +69,7 @@ import ( signal "github.com/netbirdio/netbird/shared/signal/client" sProto "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/capture" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -139,6 +143,7 @@ type EngineConfig struct { ProfileConfig *profilemanager.Config LogPath string + TempDir string } // EngineServices holds the external service dependencies required by the Engine. @@ -210,9 +215,12 @@ type Engine struct { // checks are the client-applied posture checks that need to be evaluated on the client checks []*mgmProto.Checks - relayManager *relayClient.Manager - stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + relayManager *relayClient.Manager + stateManager *statemanager.Manager + portForwardManager *portforward.Manager + srWatcher *guard.SRWatcher + + afpacketCapture *capture.AFPacketCapture // Sync response persistence (protected by syncRespMux) syncRespMux sync.RWMutex @@ -259,26 +267,27 @@ func NewEngine( mobileDep MobileDependency, ) *Engine { engine := &Engine{ - clientCtx: clientCtx, - clientCancel: clientCancel, - signal: services.SignalClient, - signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), - mgmClient: services.MgmClient, - relayManager: services.RelayManager, - peerStore: peerstore.NewConnStore(), - syncMsgMux: &sync.Mutex{}, - config: config, - mobileDep: mobileDep, - STUNs: []*stun.URI{}, - TURNs: []*stun.URI{}, - networkSerial: 0, - statusRecorder: services.StatusRecorder, - stateManager: services.StateManager, - checks: services.Checks, - probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), - jobExecutor: jobexec.NewExecutor(), - clientMetrics: services.ClientMetrics, - updateManager: services.UpdateManager, + clientCtx: clientCtx, + clientCancel: clientCancel, + signal: services.SignalClient, + signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), + mgmClient: services.MgmClient, + relayManager: services.RelayManager, + peerStore: peerstore.NewConnStore(), + syncMsgMux: &sync.Mutex{}, + config: config, + mobileDep: mobileDep, + STUNs: []*stun.URI{}, + TURNs: []*stun.URI{}, + networkSerial: 0, + statusRecorder: services.StatusRecorder, + stateManager: services.StateManager, + portForwardManager: portforward.NewManager(), + checks: services.Checks, + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), + jobExecutor: jobexec.NewExecutor(), + clientMetrics: services.ClientMetrics, + updateManager: services.UpdateManager, } log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) @@ -500,7 +509,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool { - for _, routes := range e.routeManager.GetClientRoutes() { + for _, routes := range e.routeManager.GetSelectedClientRoutes() { for _, r := range routes { if r.Network.Contains(ip) { return true @@ -521,6 +530,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return err } + // Inject firewall into DNS server now that it's available. + // The DNS server is created before the firewall because the route manager + // depends on the DNS server, and the firewall depends on the wg interface. + e.dnsServer.SetFirewall(e.firewall) + e.udpMux, err = e.wgInterface.Up() if err != nil { log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) @@ -532,6 +546,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // conntrack entries from being created before the rules are in place e.setupWGProxyNoTrack() + // Start after interface is up since port may have been resolved from 0 or changed if occupied + e.shutdownWg.Add(1) + go func() { + defer e.shutdownWg.Done() + e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) + }() + // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { e.rpManager.SetInterface(e.wgInterface) @@ -554,7 +575,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.connMgr.Start(e.ctx) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) - e.srWatcher.Start() + e.srWatcher.Start(peer.IsForceRelayed()) e.receiveSignalEvents() e.receiveManagementEvents() @@ -588,6 +609,8 @@ func (e *Engine) createFirewall() error { return nil } + firewalld.SetParentContext(e.ctx) + var err error e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil { @@ -925,7 +948,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error { return fmt.Errorf("update relay token: %w", err) } - e.relayManager.UpdateServerURLs(update.Urls) + urls := update.Urls + if override, ok := peer.OverrideRelayURLs(); ok { + log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override) + urls = override + } + e.relayManager.UpdateServerURLs(urls) // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. // We can ignore all errors because the guard will manage the reconnection retries. @@ -1080,6 +1108,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR StatusRecorder: e.statusRecorder, SyncResponse: syncResponse, LogPath: e.config.LogPath, + TempDir: e.config.TempDir, ClientMetrics: e.clientMetrics, RefreshStatus: func() { e.RunHealthProbes(true) @@ -1535,12 +1564,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV } serviceDependencies := peer.ServiceDependencies{ - StatusRecorder: e.statusRecorder, - Signaler: e.signaler, - IFaceDiscover: e.mobileDep.IFaceDiscover, - RelayManager: e.relayManager, - SrWatcher: e.srWatcher, - MetricsRecorder: e.clientMetrics, + StatusRecorder: e.statusRecorder, + Signaler: e.signaler, + IFaceDiscover: e.mobileDep.IFaceDiscover, + RelayManager: e.relayManager, + SrWatcher: e.srWatcher, + PortForwardManager: e.portForwardManager, + MetricsRecorder: e.clientMetrics, } peerConn, err := peer.NewConn(config, serviceDependencies) if err != nil { @@ -1677,6 +1707,11 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { @@ -1697,6 +1732,12 @@ func (e *Engine) close() { if e.rpManager != nil { _ = e.rpManager.Close() } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := e.portForwardManager.GracefullyStop(ctx); err != nil { + log.Warnf("failed to gracefully stop port forwarding manager: %s", err) + } } func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { @@ -1800,7 +1841,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS) return dnsServer, nil default: @@ -1837,6 +1878,11 @@ func (e *Engine) GetExposeManager() *expose.Manager { return e.exposeManager } +// IsBlockInbound returns whether inbound connections are blocked. +func (e *Engine) IsBlockInbound() bool { + return e.config.BlockInbound +} + // GetClientMetrics returns the client metrics func (e *Engine) GetClientMetrics() *metrics.ClientMetrics { return e.clientMetrics @@ -2131,6 +2177,62 @@ func (e *Engine) Address() (netip.Addr, error) { return e.wgInterface.Address().IP, nil } +// SetCapture sets or clears packet capture on the WireGuard device. +// On userspace WireGuard, it taps the FilteredDevice directly. +// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture. +// Pass nil to disable capture. +func (e *Engine) SetCapture(pc device.PacketCapture) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + intf := e.wgInterface + if intf == nil { + return errors.New("wireguard interface not initialized") + } + + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + + dev := intf.GetDevice() + if dev != nil { + dev.SetCapture(pc) + e.setForwarderCapture(pc) + return nil + } + + // Kernel mode: no FilteredDevice. Use AF_PACKET on Linux. + if pc == nil { + return nil + } + sess, ok := pc.(*capture.Session) + if !ok { + return errors.New("filtered device not available and AF_PACKET requires *capture.Session") + } + + afc := capture.NewAFPacketCapture(intf.Name(), sess) + if err := afc.Start(); err != nil { + return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err) + } + e.afpacketCapture = afc + return nil +} + +// setForwarderCapture propagates capture to the USP filter's forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. +func (e *Engine) setForwarderCapture(pc device.PacketCapture) { + if e.firewall == nil { + return + } + type forwarderCapturer interface { + SetPacketCapture(pc forwarder.PacketCapture) + } + if fc, ok := e.firewall.(forwarderCapturer); ok { + fc.SetPacketCapture(pc) + } +} + func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { if e.firewall == nil { log.Warn("firewall is disabled, not updating forwarding rules") @@ -2352,6 +2454,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) { } } + relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP()) + offerAnswer := peer.OfferAnswer{ IceCredentials: peer.IceCredentials{ UFrag: remoteCred.UFrag, @@ -2362,7 +2466,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) { RosenpassPubKey: rosenpassPubKey, RosenpassAddr: rosenpassAddr, RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), + RelaySrvIP: relayIP, SessionID: sessionID, } return &offerAnswer, nil } + +// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a +// netip.Addr. Returns the zero value for empty input and logs a warning +// for malformed payloads. +func decodeRelayIP(b []byte) netip.Addr { + if len(b) == 0 { + return netip.Addr{} + } + ip, ok := netip.AddrFromSlice(b) + if !ok { + log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b)) + return netip.Addr{} + } + return ip.Unmap() +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 77fe9049b..f4c5be70a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -55,6 +55,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -828,7 +829,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, EngineServices{ + }, EngineServices{ SignalClient: &signal.MockClient{}, MgmClient: &mgmt.MockClient{}, RelayManager: relayMgr, @@ -1035,7 +1036,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, EngineServices{ + }, EngineServices{ SignalClient: &signal.MockClient{}, MgmClient: &mgmt.MockClient{}, RelayManager: relayMgr, @@ -1538,13 +1539,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin return nil, err } - publicKey, err := mgmtClient.GetServerPublicKey() - if err != nil { - return nil, err - } - info := system.GetInfo(ctx) - resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil) + resp, err := mgmtClient.Register(setupKey, "", info, nil, nil) if err != nil { return nil, err } @@ -1566,7 +1562,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) -e, err := NewEngine(ctx, cancel, conf, EngineServices{ + e, err := NewEngine(ctx, cancel, conf, EngineServices{ SignalClient: signalClient, MgmClient: mgmtClient, RelayManager: relayMgr, @@ -1639,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri peersManager := peers.NewManager(store, permissionsManager) jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1661,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, "", err } @@ -1670,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } diff --git a/client/internal/expose/manager.go b/client/internal/expose/manager.go index c59a1a7bd..076f92043 100644 --- a/client/internal/expose/manager.go +++ b/client/internal/expose/manager.go @@ -4,11 +4,14 @@ import ( "context" "time" - mgm "github.com/netbirdio/netbird/shared/management/client" log "github.com/sirupsen/logrus" + + mgm "github.com/netbirdio/netbird/shared/management/client" ) -const renewTimeout = 10 * time.Second +const ( + renewTimeout = 10 * time.Second +) // Response holds the response from exposing a service. type Response struct { @@ -18,11 +21,13 @@ type Response struct { PortAutoAssigned bool } +// Request holds the parameters for exposing a local service via the management server. +// It is part of the embed API surface and exposed via a type alias. type Request struct { NamePrefix string Domain string Port uint16 - Protocol int + Protocol ProtocolType Pin string Password string UserGroups []string @@ -59,6 +64,8 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) { return fromClientExposeResponse(resp), nil } +// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs. +// It is part of the embed API surface and exposed via a type alias. func (m *Manager) KeepAlive(ctx context.Context, domain string) error { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() diff --git a/client/internal/expose/manager_test.go b/client/internal/expose/manager_test.go index 87d43cdb0..7d76c9838 100644 --- a/client/internal/expose/manager_test.go +++ b/client/internal/expose/manager_test.go @@ -86,7 +86,7 @@ func TestNewRequest(t *testing.T) { exposeReq := NewRequest(req) assert.Equal(t, uint16(8080), exposeReq.Port, "port should match") - assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match") + assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match") assert.Equal(t, "123456", exposeReq.Pin, "pin should match") assert.Equal(t, "secret", exposeReq.Password, "password should match") assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match") diff --git a/client/internal/expose/protocol.go b/client/internal/expose/protocol.go new file mode 100644 index 000000000..d5026d51e --- /dev/null +++ b/client/internal/expose/protocol.go @@ -0,0 +1,40 @@ +package expose + +import ( + "fmt" + "strings" +) + +// ProtocolType represents the protocol used for exposing a service. +type ProtocolType int + +const ( + // ProtocolHTTP exposes the service as HTTP. + ProtocolHTTP ProtocolType = 0 + // ProtocolHTTPS exposes the service as HTTPS. + ProtocolHTTPS ProtocolType = 1 + // ProtocolTCP exposes the service as TCP. + ProtocolTCP ProtocolType = 2 + // ProtocolUDP exposes the service as UDP. + ProtocolUDP ProtocolType = 3 + // ProtocolTLS exposes the service as TLS. + ProtocolTLS ProtocolType = 4 +) + +// ParseProtocolType parses a protocol string into a ProtocolType. +func ParseProtocolType(s string) (ProtocolType, error) { + switch strings.ToLower(s) { + case "http": + return ProtocolHTTP, nil + case "https": + return ProtocolHTTPS, nil + case "tcp": + return ProtocolTCP, nil + case "udp": + return ProtocolUDP, nil + case "tls": + return ProtocolTLS, nil + default: + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s) + } +} diff --git a/client/internal/expose/request.go b/client/internal/expose/request.go index bff4f2ce7..ec75bb276 100644 --- a/client/internal/expose/request.go +++ b/client/internal/expose/request.go @@ -9,7 +9,7 @@ import ( func NewRequest(req *daemonProto.ExposeServiceRequest) *Request { return &Request{ Port: uint16(req.Port), - Protocol: int(req.Protocol), + Protocol: ProtocolType(req.Protocol), Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, @@ -24,7 +24,7 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest { NamePrefix: req.NamePrefix, Domain: req.Domain, Port: req.Port, - Protocol: req.Protocol, + Protocol: int(req.Protocol), Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go index f86dd3877..1baaae6be 100644 --- a/client/internal/lazyconn/activity/listener_bind_test.go +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -3,7 +3,6 @@ package activity import ( "net" "net/netip" - "runtime" "testing" "time" @@ -18,10 +17,6 @@ import ( peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) -func isBindListenerPlatform() bool { - return runtime.GOOS == "windows" || runtime.GOOS == "js" -} - // mockEndpointManager implements device.EndpointManager for testing type mockEndpointManager struct { endpoints map[netip.Addr]net.Conn @@ -181,10 +176,6 @@ func TestBindListener_Close(t *testing.T) { } func TestManager_BindMode(t *testing.T) { - if !isBindListenerPlatform() { - t.Skip("BindListener only used on Windows/JS platforms") - } - mockEndpointMgr := newMockEndpointManager() mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} @@ -226,10 +217,6 @@ func TestManager_BindMode(t *testing.T) { } func TestManager_BindMode_MultiplePeers(t *testing.T) { - if !isBindListenerPlatform() { - t.Skip("BindListener only used on Windows/JS platforms") - } - mockEndpointMgr := newMockEndpointManager() mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 1c11378c8..cccc0669f 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -4,14 +4,12 @@ import ( "errors" "net" "net/netip" - "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" @@ -75,16 +73,6 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) return NewUDPListener(m.wgIface, peerCfg) } - // BindListener is used on Windows, JS, and netstack platforms: - // - JS: Cannot listen to UDP sockets - // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default - // gateway points to, preventing them from reaching the loopback interface. - // - Netstack: Allows multiple instances on the same host without port conflicts. - // BindListener bypasses these issues by passing data directly through the bind. - if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() { - return NewUDPListener(m.wgIface, peerCfg) - } - provider, ok := m.wgIface.(bindProvider) if !ok { return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index b6b3c6091..fc47bda39 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -6,7 +6,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn/activity" @@ -91,8 +90,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) { m.routesMu.Lock() defer m.routesMu.Unlock() - maps.Clear(m.peerToHAGroups) - maps.Clear(m.haGroupToPeers) + clear(m.peerToHAGroups) + clear(m.haGroupToPeers) for haUniqueID, routes := range haMap { var peers []string diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 7c95e2b99..310d61a25 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -22,4 +22,8 @@ type MobileDependency struct { DnsManager dns.IosDnsManager FileDescriptor int32 StateFilePath string + + // TempDir is a writable directory for temporary files (e.g., debug bundle zip). + // On Android, this should be set to the app's cache directory. + TempDir string } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index a4ffa3a25..2420b1fdf 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -7,7 +7,9 @@ import ( "fmt" "net/netip" "sync" + "time" + "github.com/cenkalti/backoff/v4" "github.com/google/uuid" log "github.com/sirupsen/logrus" nfct "github.com/ti-mo/conntrack" @@ -17,31 +19,64 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -const defaultChannelSize = 100 +const ( + defaultChannelSize = 100 + reconnectInitInterval = 5 * time.Second + reconnectMaxInterval = 5 * time.Minute + reconnectRandomization = 0.5 +) + +// listener abstracts a netlink conntrack connection for testability. +type listener interface { + Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error) + Close() error +} // ConnTrack manages kernel-based conntrack events type ConnTrack struct { flowLogger nftypes.FlowLogger iface nftypes.IFaceMapper - conn *nfct.Conn + conn listener mux sync.Mutex + dial func() (listener, error) instanceID uuid.UUID started bool done chan struct{} sysctlModified bool } +// DialFunc is a constructor for netlink conntrack connections. +type DialFunc func() (listener, error) + +// Option configures a ConnTrack instance. +type Option func(*ConnTrack) + +// WithDialer overrides the default netlink dialer, primarily for testing. +func WithDialer(dial DialFunc) Option { + return func(c *ConnTrack) { + c.dial = dial + } +} + +func defaultDial() (listener, error) { + return nfct.Dial(nil) +} + // New creates a new connection tracker that interfaces with the kernel's conntrack system -func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { - return &ConnTrack{ +func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack { + ct := &ConnTrack{ flowLogger: flowLogger, iface: iface, instanceID: uuid.New(), - started: false, + dial: defaultDial, done: make(chan struct{}, 1), } + for _, opt := range opts { + opt(ct) + } + return ct } // Start begins tracking connections by listening for conntrack events. This method is idempotent. @@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error { c.EnableAccounting() } - conn, err := nfct.Dial(nil) + conn, err := c.dial() if err != nil { + c.RestoreAccounting() return fmt.Errorf("dial conntrack: %w", err) } c.conn = conn @@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error { log.Errorf("Error closing conntrack connection: %v", err) } c.conn = nil + c.RestoreAccounting() return fmt.Errorf("start conntrack listener: %w", err) } + // Drain any stale stop signal from a previous cycle. + select { + case <-c.done: + default: + } + c.started = true go c.receiverRoutine(events, errChan) @@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) case event := <-events: c.handleEvent(event) case err := <-errChan: - log.Errorf("Error from conntrack event listener: %v", err) - if err := c.conn.Close(); err != nil { - log.Errorf("Error closing conntrack connection: %v", err) + if events, errChan = c.handleListenerError(err); events == nil { + return } - return case <-c.done: return } } } +// handleListenerError closes the failed connection and attempts to reconnect. +// Returns new channels on success, or nil if shutdown was requested. +func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) { + log.Warnf("conntrack event listener failed: %v", err) + c.closeConn() + return c.reconnect() +} + +func (c *ConnTrack) closeConn() { + c.mux.Lock() + defer c.mux.Unlock() + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Debugf("close conntrack connection: %v", err) + } + c.conn = nil + } +} + +// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff. +// Returns new channels on success, or nil if shutdown was requested. +func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) { + bo := &backoff.ExponentialBackOff{ + InitialInterval: reconnectInitInterval, + RandomizationFactor: reconnectRandomization, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: reconnectMaxInterval, + MaxElapsedTime: 0, // retry indefinitely + Clock: backoff.SystemClock, + } + bo.Reset() + + for { + delay := bo.NextBackOff() + log.Infof("reconnecting conntrack listener in %s", delay) + + select { + case <-c.done: + c.mux.Lock() + c.started = false + c.mux.Unlock() + return nil, nil + case <-time.After(delay): + } + + conn, err := c.dial() + if err != nil { + log.Warnf("reconnect conntrack dial: %v", err) + continue + } + + events := make(chan nfct.Event, defaultChannelSize) + errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{ + netfilter.GroupCTNew, + netfilter.GroupCTDestroy, + }) + if err != nil { + log.Warnf("reconnect conntrack listen: %v", err) + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + continue + } + + c.mux.Lock() + if !c.started { + // Stop() ran while we were reconnecting. + c.mux.Unlock() + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + return nil, nil + } + c.conn = conn + c.mux.Unlock() + + log.Infof("conntrack listener reconnected successfully") + + return events, errChan + } +} + // Stop stops the connection tracking. This method is idempotent. func (c *ConnTrack) Stop() { c.mux.Lock() @@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error { c.mux.Lock() defer c.mux.Unlock() - if c.started { - select { - case c.done <- struct{}{}: - default: - } + if !c.started { + return nil } + select { + case c.done <- struct{}{}: + default: + } + + c.started = false + + var closeErr error if c.conn != nil { - err := c.conn.Close() + closeErr = c.conn.Close() c.conn = nil - c.started = false + } - c.RestoreAccounting() + c.RestoreAccounting() - if err != nil { - return fmt.Errorf("close conntrack: %w", err) - } + if closeErr != nil { + return fmt.Errorf("close conntrack: %w", closeErr) } return nil diff --git a/client/internal/netflow/conntrack/conntrack_test.go b/client/internal/netflow/conntrack/conntrack_test.go new file mode 100644 index 000000000..35ceec90d --- /dev/null +++ b/client/internal/netflow/conntrack/conntrack_test.go @@ -0,0 +1,224 @@ +//go:build linux && !android + +package conntrack + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nfct "github.com/ti-mo/conntrack" + "github.com/ti-mo/netfilter" +) + +type mockListener struct { + errChan chan error + closed atomic.Bool + closedCh chan struct{} +} + +func newMockListener() *mockListener { + return &mockListener{ + errChan: make(chan error, 1), + closedCh: make(chan struct{}), + } +} + +func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) { + return m.errChan, nil +} + +func (m *mockListener) Close() error { + if m.closed.CompareAndSwap(false, true) { + close(m.closedCh) + } + return nil +} + +func TestReconnectAfterError(t *testing.T) { + first := newMockListener() + second := newMockListener() + third := newMockListener() + listeners := []*mockListener{first, second, third} + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := int(callCount.Add(1)) - 1 + return listeners[n], nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Inject an error on the first listener. + first.errChan <- assert.AnError + + // Wait for reconnect to complete. + require.Eventually(t, func() bool { + return callCount.Load() >= 2 + }, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection") + + // The first connection must have been closed. + select { + case <-first.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("first connection was not closed") + } + + // Verify the receiver is still running by injecting and handling a second error. + second.errChan <- assert.AnError + + require.Eventually(t, func() bool { + return callCount.Load() >= 3 + }, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed") + + ct.Stop() +} + +func TestStopDuringReconnectBackoff(t *testing.T) { + mock := newMockListener() + + ct := New(nil, nil, WithDialer(func() (listener, error) { + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger an error so the receiver enters reconnect. + mock.errChan <- assert.AnError + + // Wait for the error handler to close the old listener before calling Stop. + select { + case <-mock.closedCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for reconnect to start") + } + + // Stop while reconnecting. + ct.Stop() + + ct.mux.Lock() + assert.False(t, ct.started, "started should be false after Stop") + assert.Nil(t, ct.conn, "conn should be nil after Stop") + ct.mux.Unlock() +} + +func TestStopRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + // Second dial: signal that we're in progress, wait for test to call Stop. + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger error to enter reconnect. + first.errChan <- assert.AnError + + // Wait for reconnect's second dial to begin. + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Stop while dial is in progress (conn is nil at this point). + ct.Stop() + + // Let the dial complete. reconnect should detect started==false and close the new conn. + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Stop") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestCloseRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + first.errChan <- assert.AnError + + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Close while dial is in progress (conn is nil). + require.NoError(t, ct.Close()) + + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Close") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestStartIsIdempotent(t *testing.T) { + mock := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + callCount.Add(1) + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Second Start should be a no-op. + err = ct.Start(false) + require.NoError(t, err) + + assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once") + + ct.Stop() +} diff --git a/client/internal/netflow/store/memory.go b/client/internal/netflow/store/memory.go index b695a0a12..a44505e96 100644 --- a/client/internal/netflow/store/memory.go +++ b/client/internal/netflow/store/memory.go @@ -3,8 +3,6 @@ package store import ( "sync" - "golang.org/x/exp/maps" - "github.com/google/uuid" "github.com/netbirdio/netbird/client/internal/netflow/types" @@ -30,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) { func (m *Memory) Close() { m.mux.Lock() defer m.mux.Unlock() - maps.Clear(m.events) + clear(m.events) } func (m *Memory) GetEvents() []*types.Event { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index bea0725f2..1e416bfe7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -22,6 +22,7 @@ import ( icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/worker" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" @@ -45,6 +46,7 @@ type ServiceDependencies struct { RelayManager *relayClient.Manager SrWatcher *guard.SRWatcher PeerConnDispatcher *dispatcher.ConnectionDispatcher + PortForwardManager *portforward.Manager MetricsRecorder MetricsRecorder } @@ -87,16 +89,17 @@ type ConnConfig struct { } type Conn struct { - Log *log.Entry - mu sync.Mutex - ctx context.Context - ctxCancel context.CancelFunc - config ConnConfig - statusRecorder *Status - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - relayManager *relayClient.Manager - srWatcher *guard.SRWatcher + Log *log.Entry + mu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + config ConnConfig + statusRecorder *Status + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + relayManager *relayClient.Manager + srWatcher *guard.SRWatcher + portForwardManager *portforward.Manager onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onDisconnected func(remotePeer string) @@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { dumpState := newStateDump(config.Key, connLog, services.StatusRecorder) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: dumpState, - endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), - wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), - metricsRecorder: services.MetricsRecorder, + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + portForwardManager: services.PortForwardManager, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: dumpState, + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), + wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), + metricsRecorder: services.MetricsRecorder, } return conn, nil @@ -181,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) - relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) - if err != nil { - return err + forceRelay := IsForceRelayed() + if !forceRelay { + relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() + workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) + if err != nil { + return err + } + conn.workerICE = workerICE } - conn.workerICE = workerICE conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) - if !isForceRelayed() { + if !forceRelay { conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } @@ -247,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgWatcherCancel() } conn.workerRelay.CloseConn() - conn.workerICE.Close() + if conn.workerICE != nil { + conn.workerICE.Close() + } if conn.wgProxyRelay != nil { err := conn.wgProxyRelay.CloseConn() @@ -290,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) { // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { conn.dumpState.RemoteCandidate() - conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + if conn.workerICE != nil { + conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + } } // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established @@ -708,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusConnecting } -func (conn *Conn) isConnectedOnAllWay() (connected bool) { - // would be better to protect this with a mutex, but it could cause deadlock with Close function - +// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports. +// +// The result is a tri-state: +// - ConnStatusConnected: all available transports are up +// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting +// - ConnStatusDisconnected: no working transport +func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) { defer func() { - if !connected { + if status == guard.ConnStatusDisconnected { conn.logTraceConnState() } }() - // For JS platform: only relay connection is supported - if runtime.GOOS == "js" { - return conn.statusRelay.Get() == worker.StatusConnected + iceWorkerCreated := conn.workerICE != nil + + var iceInProgress bool + if iceWorkerCreated { + iceInProgress = conn.workerICE.InProgress() } - // For non-JS platforms: check ICE connection status - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { - return false - } - - // If relay is supported with peer, it must also be connected - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == worker.StatusDisconnected { - return false - } - } - - return true + return evalConnStatus(connStatusInputs{ + forceRelay: IsForceRelayed(), + peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(), + relayConnected: conn.statusRelay.Get() == worker.StatusConnected, + remoteSupportsICE: conn.handshaker.RemoteICESupported(), + iceWorkerCreated: iceWorkerCreated, + iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected, + iceInProgress: iceInProgress, + }) } func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { @@ -922,3 +935,43 @@ func isController(config ConnConfig) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } + +func evalConnStatus(in connStatusInputs) guard.ConnStatus { + // "Relay up and needed" — the peer uses relay and the transport is connected. + relayUsedAndUp := in.peerUsesRelay && in.relayConnected + + // Force-relay mode: ICE never runs. Relay is the only transport and must be up. + if in.forceRelay { + return boolToConnStatus(relayUsedAndUp) + } + + // Remote peer doesn't support ICE, or we haven't created the worker yet: + // relay is the only possible transport. + if !in.remoteSupportsICE || !in.iceWorkerCreated { + return boolToConnStatus(relayUsedAndUp) + } + + // ICE counts as "up" when the status is anything other than Disconnected, OR + // when a negotiation is currently in progress (so we don't spam offers while one is in flight). + iceUp := in.iceStatusConnecting || in.iceInProgress + + // Relay side is acceptable if the peer doesn't rely on relay, or relay is connected. + relayOK := !in.peerUsesRelay || in.relayConnected + + switch { + case iceUp && relayOK: + return guard.ConnStatusConnected + case relayUsedAndUp: + // Relay is up but ICE is down — partially connected. + return guard.ConnStatusPartiallyConnected + default: + return guard.ConnStatusDisconnected + } +} + +func boolToConnStatus(connected bool) guard.ConnStatus { + if connected { + return guard.ConnStatusConnected + } + return guard.ConnStatusDisconnected +} diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go index 73acc5ef5..b43e245f3 100644 --- a/client/internal/peer/conn_status.go +++ b/client/internal/peer/conn_status.go @@ -13,6 +13,20 @@ const ( StatusConnected ) +// connStatusInputs is the primitive-valued snapshot of the state that drives the +// tri-state connection classification. Extracted so the decision logic can be unit-tested +// without constructing full Worker/Handshaker objects. +type connStatusInputs struct { + forceRelay bool // NB_FORCE_RELAY or JS/WASM + peerUsesRelay bool // remote peer advertises relay support AND local has relay + relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay) + remoteSupportsICE bool // remote peer sent ICE credentials + iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode) + iceStatusConnecting bool // statusICE is anything other than Disconnected + iceInProgress bool // a negotiation is currently in flight +} + + // ConnStatus describe the status of a peer's connection type ConnStatus int32 diff --git a/client/internal/peer/conn_status_eval_test.go b/client/internal/peer/conn_status_eval_test.go new file mode 100644 index 000000000..66393cafe --- /dev/null +++ b/client/internal/peer/conn_status_eval_test.go @@ -0,0 +1,201 @@ +package peer + +import ( + "testing" + + "github.com/netbirdio/netbird/client/internal/peer/guard" +) + +func TestEvalConnStatus_ForceRelay(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "force relay, peer uses relay, relay up", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "force relay, peer uses relay, relay down", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: false, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "force relay, peer does NOT use relay - disconnected forever", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: false, + relayConnected: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_ICEUnavailable(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "remote does not support ICE, peer uses relay, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer uses relay, relay down", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE worker not yet created, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: true, + iceWorkerCreated: false, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer does not use relay", + in: connStatusInputs{ + peerUsesRelay: false, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_FullyAvailable(t *testing.T) { + base := connStatusInputs{ + remoteSupportsICE: true, + iceWorkerCreated: true, + } + + tests := []struct { + name string + mutator func(*connStatusInputs) + want guard.ConnStatus + }{ + { + name: "ICE connected, relay connected, peer uses relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE connected, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE InProgress only, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.iceStatusConnecting = false + in.iceInProgress = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE down, relay up, peer uses relay -> partial", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusPartiallyConnected, + }, + { + name: "ICE down, peer does NOT use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = false + in.iceStatusConnecting = true + }, + // relayOK = false (peer uses relay but it's down), iceUp = true + // first switch arm fails (relayOK false), relayUsedAndUp = false (relay down), + // falls into default: Disconnected. + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE down, relay up but peer does not use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = true // not actually used since peer doesn't rely on it + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + in := base + tc.mutator(&in) + if got := evalConnStatus(in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in) + } + }) + } +} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 7f500c410..ed6a3af53 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -7,12 +7,38 @@ import ( ) const ( - EnvKeyNBForceRelay = "NB_FORCE_RELAY" + EnvKeyNBForceRelay = "NB_FORCE_RELAY" + EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS" ) -func isForceRelayed() bool { +func IsForceRelayed() bool { if runtime.GOOS == "js" { return true } return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") } + +// OverrideRelayURLs returns the relay server URL list set in +// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether +// the override is active. When the env var is unset, the boolean is false +// and the caller should keep the list received from the management server. +// Intended for lab/debug scenarios where a peer must pin to a specific home +// relay regardless of what management offers. +func OverrideRelayURLs() ([]string, bool) { + raw := os.Getenv(EnvKeyNBHomeRelayServers) + if raw == "" { + return nil, false + } + parts := strings.Split(raw, ",") + urls := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + urls = append(urls, p) + } + } + if len(urls) == 0 { + return nil, false + } + return urls, true +} diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go index d93403730..2e5efbcc5 100644 --- a/client/internal/peer/guard/guard.go +++ b/client/internal/peer/guard/guard.go @@ -8,7 +8,19 @@ import ( log "github.com/sirupsen/logrus" ) -type isConnectedFunc func() bool +// ConnStatus represents the connection state as seen by the guard. +type ConnStatus int + +const ( + // ConnStatusDisconnected means neither ICE nor Relay is connected. + ConnStatusDisconnected ConnStatus = iota + // ConnStatusPartiallyConnected means Relay is connected but ICE is not. + ConnStatusPartiallyConnected + // ConnStatusConnected means all required connections are established. + ConnStatusConnected +) + +type connStatusFunc func() ConnStatus // Guard is responsible for the reconnection logic. // It will trigger to send an offer to the peer then has connection issues. @@ -20,14 +32,14 @@ type isConnectedFunc func() bool // - ICE candidate changes type Guard struct { log *log.Entry - isConnectedOnAllWay isConnectedFunc + isConnectedOnAllWay connStatusFunc timeout time.Duration srWatcher *SRWatcher relayedConnDisconnected chan struct{} iCEConnDisconnected chan struct{} } -func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { +func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { return &Guard{ log: log, isConnectedOnAllWay: isConnectedFn, @@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() { } } -// reconnectLoopWithRetry periodically check the connection status. -// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported +// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity. +// +// Behavior depends on the connection state reported by isConnectedOnAllWay: +// - Connected: no action, the peer is fully reachable. +// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling +// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all. +// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches +// to one attempt per hour. This limits signaling traffic when relay already provides connectivity. +// +// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry +// counter and backoff ticker, giving ICE a fresh chance after network conditions change. func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { srReconnectedChan := g.srWatcher.NewListener() defer g.srWatcher.RemoveListener(srReconnectedChan) @@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { tickerChannel := ticker.C + iceState := &iceRetryState{log: g.log} + defer iceState.reset() + for { select { - case t := <-tickerChannel: - if t.IsZero() { - g.log.Infof("retry timed out, stop periodic offer sending") - // after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop - tickerChannel = make(<-chan time.Time) - continue + case <-tickerChannel: + switch g.isConnectedOnAllWay() { + case ConnStatusConnected: + // all good, nothing to do + case ConnStatusDisconnected: + callback() + case ConnStatusPartiallyConnected: + if iceState.shouldRetry() { + callback() + } else { + iceState.enterHourlyMode() + ticker.Stop() + tickerChannel = iceState.hourlyC() + } } - if !g.isConnectedOnAllWay() { - callback() - } case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-g.iCEConnDisconnected: g.log.Debugf("ICE connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-srReconnectedChan: g.log.Debugf("has network changes, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-ctx.Done(): g.log.Debugf("context is done, stop reconnect loop") @@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker { return backoff.NewTicker(bo) } -func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, RandomizationFactor: 0.1, diff --git a/client/internal/peer/guard/ice_retry_state.go b/client/internal/peer/guard/ice_retry_state.go new file mode 100644 index 000000000..01dc1bf2d --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state.go @@ -0,0 +1,61 @@ +package guard + +import ( + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // maxICERetries is the maximum number of ICE offer attempts when relay is connected + maxICERetries = 3 + // iceRetryInterval is the periodic retry interval after ICE retries are exhausted + iceRetryInterval = 1 * time.Hour +) + +// iceRetryState tracks the limited ICE retry attempts when relay is already connected. +// After maxICERetries attempts it switches to a periodic hourly retry. +type iceRetryState struct { + log *log.Entry + retries int + hourly *time.Ticker +} + +func (s *iceRetryState) reset() { + s.retries = 0 + if s.hourly != nil { + s.hourly.Stop() + s.hourly = nil + } +} + +// shouldRetry reports whether the caller should send another ICE offer on this tick. +// Returns false when the per-cycle retry budget is exhausted and the caller must switch +// to the hourly ticker via enterHourlyMode + hourlyC. +func (s *iceRetryState) shouldRetry() bool { + if s.hourly != nil { + s.log.Debugf("hourly ICE retry attempt") + return true + } + + s.retries++ + if s.retries <= maxICERetries { + s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries) + return true + } + + return false +} + +// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false. +func (s *iceRetryState) enterHourlyMode() { + s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries) + s.hourly = time.NewTicker(iceRetryInterval) +} + +func (s *iceRetryState) hourlyC() <-chan time.Time { + if s.hourly == nil { + return nil + } + return s.hourly.C +} diff --git a/client/internal/peer/guard/ice_retry_state_test.go b/client/internal/peer/guard/ice_retry_state_test.go new file mode 100644 index 000000000..6a5b5a76f --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state_test.go @@ -0,0 +1,103 @@ +package guard + +import ( + "testing" + + log "github.com/sirupsen/logrus" +) + +func newTestRetryState() *iceRetryState { + return &iceRetryState{log: log.NewEntry(log.StandardLogger())} +} + +func TestICERetryState_AllowsInitialBudget(t *testing.T) { + s := newTestRetryState() + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries) + } + } +} + +func TestICERetryState_ExhaustsAfterBudget(t *testing.T) { + s := newTestRetryState() + + for i := 0; i < maxICERetries; i++ { + _ = s.shouldRetry() + } + + if s.shouldRetry() { + t.Fatalf("shouldRetry returned true after budget exhausted, want false") + } +} + +func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) { + s := newTestRetryState() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode") + } +} + +func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + + s.enterHourlyMode() + defer s.reset() + + if s.hourlyC() == nil { + t.Fatalf("hourlyC returned nil after enterHourlyMode") + } +} + +func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) { + s := newTestRetryState() + s.enterHourlyMode() + defer s.reset() + + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false in hourly mode, want true") + } + + // Subsequent calls also return true — we keep retrying on each hourly tick. + if !s.shouldRetry() { + t.Fatalf("second shouldRetry returned false in hourly mode, want true") + } +} + +func TestICERetryState_ResetRestoresBudget(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + s.enterHourlyMode() + + s.reset() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel after reset") + } + if s.retries != 0 { + t.Fatalf("retries = %d after reset, want 0", s.retries) + } + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i) + } + } +} + +func TestICERetryState_ResetIsIdempotent(t *testing.T) { + s := newTestRetryState() + s.reset() + s.reset() // second call must not panic or re-stop a nil ticker + + if s.hourlyC() != nil { + t.Fatalf("hourlyC non-nil after double reset") + } +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 6f4f5ad4f..0befd7438 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove return srw } -func (w *SRWatcher) Start() { +func (w *SRWatcher) Start(disableICEMonitor bool) { w.mu.Lock() defer w.mu.Unlock() @@ -50,8 +50,10 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) - go iceMonitor.Start(ctx, w.onICEChanged) + if !disableICEMonitor { + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) + go iceMonitor.Start(ctx, w.onICEChanged) + } w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 9b50cecd1..1d44096b6 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -3,7 +3,9 @@ package peer import ( "context" "errors" + "net/netip" "sync" + "sync/atomic" log "github.com/sirupsen/logrus" @@ -39,10 +41,18 @@ type OfferAnswer struct { // relay server address RelaySrvAddress string + // RelaySrvIP is the IP the remote peer is connected to on its + // relay server. Used as a dial target if DNS for RelaySrvAddress + // fails. Zero value if the peer did not advertise an IP. + RelaySrvIP netip.Addr // SessionID is the unique identifier of the session, used to discard old messages SessionID *ICESessionID } +func (o *OfferAnswer) hasICECredentials() bool { + return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != "" +} + type Handshaker struct { mu sync.Mutex log *log.Entry @@ -59,6 +69,10 @@ type Handshaker struct { relayListener *AsyncOfferListener iceListener func(remoteOfferAnswer *OfferAnswer) + // remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers. + // When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses. + remoteICESupported atomic.Bool + // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection @@ -66,7 +80,7 @@ type Handshaker struct { } func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { - return &Handshaker{ + h := &Handshaker{ log: log, config: config, signaler: signaler, @@ -76,6 +90,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W remoteOffersCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer), } + // assume remote supports ICE until we learn otherwise from received offers + h.remoteICESupported.Store(ice != nil) + return h +} + +func (h *Handshaker) RemoteICESupported() bool { + return h.remoteICESupported.Load() } func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { @@ -90,18 +111,20 @@ func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } @@ -110,18 +133,20 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } case remoteOfferAnswer := <-h.remoteAnswerCh: - h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): @@ -183,20 +208,39 @@ func (h *Handshaker) sendAnswer() error { } func (h *Handshaker) buildOfferAnswer() OfferAnswer { - uFrag, pwd := h.ice.GetLocalUserCredentials() - sid := h.ice.SessionID() answer := OfferAnswer{ - IceCredentials: IceCredentials{uFrag, pwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassAddr: h.config.RosenpassConfig.Addr, - SessionID: &sid, } - if addr, err := h.relay.RelayInstanceAddress(); err == nil { + if h.ice != nil && h.RemoteICESupported() { + uFrag, pwd := h.ice.GetLocalUserCredentials() + sid := h.ice.SessionID() + answer.IceCredentials = IceCredentials{uFrag, pwd} + answer.SessionID = &sid + } + + if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil { answer.RelaySrvAddress = addr + answer.RelaySrvIP = ip } return answer } + +func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) { + hasICE := offer.hasICECredentials() + prev := h.remoteICESupported.Swap(hasICE) + if prev != hasICE { + if hasICE { + h.log.Infof("remote peer started sending ICE credentials") + } else { + h.log.Infof("remote peer stopped sending ICE credentials") + if h.ice != nil { + h.ice.Close() + } + } + } +} diff --git a/client/internal/peer/notifier_test.go b/client/internal/peer/notifier_test.go index bbdc00e13..0b7722b0c 100644 --- a/client/internal/peer/notifier_test.go +++ b/client/internal/peer/notifier_test.go @@ -8,6 +8,7 @@ import ( type mocListener struct { lastState int wg sync.WaitGroup + peersWg sync.WaitGroup peers int } @@ -33,6 +34,7 @@ func (l *mocListener) OnAddressChanged(host, addr string) { } func (l *mocListener) OnPeersListChanged(size int) { l.peers = size + l.peersWg.Done() } func (l *mocListener) setWaiter() { @@ -43,6 +45,14 @@ func (l *mocListener) wait() { l.wg.Wait() } +func (l *mocListener) setPeersWaiter() { + l.peersWg.Add(1) +} + +func (l *mocListener) waitPeers() { + l.peersWg.Wait() +} + func Test_notifier_serverState(t *testing.T) { type scenario struct { @@ -72,11 +82,13 @@ func Test_notifier_serverState(t *testing.T) { func Test_notifier_SetListener(t *testing.T) { listener := &mocListener{} listener.setWaiter() + listener.setPeersWaiter() n := newNotifier() n.lastNotification = stateConnecting n.setListener(listener) listener.wait() + listener.waitPeers() if listener.lastState != n.lastNotification { t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) } @@ -85,9 +97,14 @@ func Test_notifier_SetListener(t *testing.T) { func Test_notifier_RemoveListener(t *testing.T) { listener := &mocListener{} listener.setWaiter() + listener.setPeersWaiter() n := newNotifier() n.lastNotification = stateConnecting n.setListener(listener) + // setListener replays cached state on a goroutine; wait for both the state + // and peers callbacks to finish so we don't race on listener.peers. + listener.wait() + listener.waitPeers() n.removeListener() n.peerListChanged(1) diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index b28906625..5e437d96b 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -46,23 +46,27 @@ func (s *Signaler) Ready() bool { // SignalOfferAnswer signals either an offer or an answer to remote peer func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { - sessionIDBytes, err := offerAnswer.SessionID.Bytes() - if err != nil { - log.Warnf("failed to get session ID bytes: %v", err) + var sessionIDBytes []byte + if offerAnswer.SessionID != nil { + var err error + sessionIDBytes, err = offerAnswer.SessionID.Bytes() + if err != nil { + log.Warnf("failed to get session ID bytes: %v", err) + } } - msg, err := signal.MarshalCredential( - s.wgPrivateKey, - offerAnswer.WgListenPort, - remoteKey, - &signal.Credential{ + msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{ + Type: bodyType, + WgListenPort: offerAnswer.WgListenPort, + Credential: &signal.Credential{ UFrag: offerAnswer.IceCredentials.UFrag, Pwd: offerAnswer.IceCredentials.Pwd, }, - bodyType, - offerAnswer.RosenpassPubKey, - offerAnswer.RosenpassAddr, - offerAnswer.RelaySrvAddress, - sessionIDBytes) + RosenpassPubKey: offerAnswer.RosenpassPubKey, + RosenpassAddr: offerAnswer.RosenpassAddr, + RelaySrvAddress: offerAnswer.RelaySrvAddress, + RelaySrvIP: offerAnswer.RelaySrvIP, + SessionID: sessionIDBytes, + }) if err != nil { return err } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index abedc208e..e8e61f660 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error { // UpdatePeerState updates peer status func (d *Status) UpdatePeerState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -343,23 +343,29 @@ func (d *Status) UpdatePeerState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } - + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) // when we close the connection we will not notify the router manager - if receivedState.ConnStatus == StatusIdle { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + notifyRouter := receivedState.ConnStatus == StatusIdle + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() + + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[peer] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -371,17 +377,20 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R d.routeIDLookup.AddRemoteRouteID(resourceId, pref) } + numPeers := d.numOfPeers() + d.mux.Unlock() + // todo: consider to make sense of this notification or not - d.notifyPeerListChanged() + d.notifier.peerListChanged(numPeers) return nil } func (d *Status) RemovePeerStateRoute(peer string, route string) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[peer] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -393,8 +402,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error { d.routeIDLookup.RemoveRemoteRouteID(pref) } + numPeers := d.numOfPeers() + d.mux.Unlock() + // todo: consider to make sense of this notification or not - d.notifyPeerListChanged() + d.notifier.peerListChanged(numPeers) return nil } @@ -410,10 +422,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) { func (d *Status) UpdatePeerICEState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -431,22 +443,28 @@ func (d *Status) UpdatePeerICEState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerRelayedState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -461,22 +479,28 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -490,22 +514,28 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -522,12 +552,18 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } @@ -594,17 +630,33 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() - defer d.mux.Unlock() if !d.peerListChangedForNotification { + d.mux.Unlock() return } d.peerListChangedForNotification = false - d.notifyPeerListChanged() + numPeers := d.numOfPeers() + // snapshot per-peer router state to deliver after the lock is released + type routerDispatch struct { + peerID string + snapshot map[string]RouterState + } + dispatches := make([]routerDispatch, 0, len(d.peers)) for key := range d.peers { - d.notifyPeerStateChangeListeners(key) + snapshot := d.snapshotRouterPeersLocked(key, true) + if snapshot != nil { + dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot}) + } + } + + d.mux.Unlock() + + d.notifier.peerListChanged(numPeers) + for _, rd := range dispatches { + d.dispatchRouterPeers(rd.peerID, rd.snapshot) } } @@ -655,10 +707,12 @@ func (d *Status) GetLocalPeerState() LocalPeerState { // UpdateLocalPeerState updates local peer status func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { d.mux.Lock() - defer d.mux.Unlock() - d.localPeer = localPeerState - d.notifyAddressChanged() + fqdn := d.localPeer.FQDN + ip := d.localPeer.IP + d.mux.Unlock() + + d.notifier.localAddressChanged(fqdn, ip) } // AddLocalPeerStateRoute adds a route to the local peer state @@ -721,30 +775,36 @@ func (d *Status) CleanLocalPeerStateRoutes() { // CleanLocalPeerState cleans local peer status func (d *Status) CleanLocalPeerState() { d.mux.Lock() - defer d.mux.Unlock() - d.localPeer = LocalPeerState{} - d.notifyAddressChanged() + fqdn := d.localPeer.FQDN + ip := d.localPeer.IP + d.mux.Unlock() + + d.notifier.localAddressChanged(fqdn, ip) } // MarkManagementDisconnected sets ManagementState to disconnected func (d *Status) MarkManagementDisconnected(err error) { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.managementState = false d.managementError = err + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // MarkManagementConnected sets ManagementState to connected func (d *Status) MarkManagementConnected() { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.managementState = true d.managementError = nil + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // UpdateSignalAddress update the address of the signal server @@ -778,21 +838,25 @@ func (d *Status) UpdateLazyConnection(enabled bool) { // MarkSignalDisconnected sets SignalState to disconnected func (d *Status) MarkSignalDisconnected(err error) { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.signalState = false d.signalError = err + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // MarkSignalConnected sets SignalState to connected func (d *Status) MarkSignalConnected() { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.signalState = true d.signalError = nil + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) { @@ -919,7 +983,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { // if the server connection is not established then we will use the general address // in case of connection we will use the instance specific address - instanceAddr, err := d.relayMgr.RelayInstanceAddress() + instanceAddr, _, err := d.relayMgr.RelayInstanceAddress() if err != nil { // TODO add their status for _, r := range d.relayMgr.ServerURLs() { @@ -1012,18 +1076,17 @@ func (d *Status) RemoveConnectionListener() { d.notifier.removeListener() } -func (d *Status) onConnectionChanged() { - d.notifier.updateServerStates(d.managementState, d.signalState) -} - -// notifyPeerStateChangeListeners notifies route manager about the change in peer state -func (d *Status) notifyPeerStateChangeListeners(peerID string) { - subs, ok := d.changeNotify[peerID] - if !ok { - return +// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers. +// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID +// or when notify is false. The snapshot is consumed later by dispatchRouterPeers +// outside the lock so the channel send cannot stall any d.mux holder. +func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState { + if !notify { + return nil + } + if _, ok := d.changeNotify[peerID]; !ok { + return nil } - - // collect the relevant data for router peers routerPeers := make(map[string]RouterState, len(d.changeNotify)) for pid := range d.changeNotify { s, ok := d.peers[pid] @@ -1031,13 +1094,35 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) { log.Warnf("router peer not found in peers list: %s", pid) continue } - routerPeers[pid] = RouterState{ Status: s.ConnStatus, Relayed: s.Relayed, Latency: s.Latency, } } + return routerPeers +} + +// dispatchRouterPeers delivers a previously snapshotted router-state map to +// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a +// fresh, short read of d.changeNotify under the lock to grab subscriber +// channels, then sends outside the lock so a slow consumer cannot block other +// d.mux holders. The send itself stays blocking (only short-circuited by the +// subscriber's context) so peer state transitions are not silently dropped. +func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) { + if routerPeers == nil { + return + } + + d.mux.Lock() + subsMap, ok := d.changeNotify[peerID] + subs := make([]*StatusChangeSubscription, 0, len(subsMap)) + if ok { + for _, sub := range subsMap { + subs = append(subs, sub) + } + } + d.mux.Unlock() for _, sub := range subs { select { @@ -1047,14 +1132,6 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) { } } -func (d *Status) notifyPeerListChanged() { - d.notifier.peerListChanged(d.numOfPeers()) -} - -func (d *Status) notifyAddressChanged() { - d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) -} - func (d *Status) numOfPeers() int { return len(d.peers) + len(d.offlinePeers) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index edd70fb20..29bf5aaaa 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) @@ -61,6 +62,9 @@ type WorkerICE struct { // we record the last known state of the ICE agent to avoid duplicate on disconnected events lastKnownState ice.ConnectionState + + // portForwardAttempted tracks if we've already tried port forwarding this session + portForwardAttempted bool } func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { @@ -214,6 +218,8 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { + w.portForwardAttempted = false + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) @@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() + + if candidate.Type() == ice.CandidateTypeServerReflexive { + w.injectPortForwardedCandidate(candidate) + } +} + +// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping. +func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { + pfManager := w.conn.portForwardManager + if pfManager == nil { + return + } + + mapping := pfManager.GetMapping() + if mapping == nil { + return + } + + w.muxAgent.Lock() + if w.portForwardAttempted { + w.muxAgent.Unlock() + return + } + w.portForwardAttempted = true + w.muxAgent.Unlock() + + forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping) + if err != nil { + w.log.Warnf("create forwarded candidate: %v", err) + return + } + + w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)", + forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority()) + + go func() { + if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil { + w.log.Errorf("signal port-forwarded candidate: %v", err) + } + }() +} + +// createForwardedCandidate creates a new server reflexive candidate with the forwarded port. +// It uses the NAT gateway's external IP with the forwarded port. +func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) { + var externalIP string + if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() { + externalIP = mapping.ExternalIP.String() + } else { + // Fallback to STUN-discovered address if NAT didn't provide external IP + externalIP = srflxCandidate.Address() + } + + // Per RFC 8445, the related address for srflx is the base (host candidate address). + // If the original srflx has unspecified related address, use its own address as base. + relAddr := srflxCandidate.RelatedAddress().Address + if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" { + relAddr = srflxCandidate.Address() + } + + // Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates + // over regular srflx during ICE connectivity checks. + priority := srflxCandidate.Priority() + 1000 + + candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + Network: srflxCandidate.NetworkType().String(), + Address: externalIP, + Port: int(mapping.ExternalPort), + Component: srflxCandidate.Component(), + Priority: priority, + RelAddr: relAddr, + RelPort: int(mapping.InternalPort), + }) + if err != nil { + return nil, fmt.Errorf("create candidate: %w", err) + } + + for _, e := range srflxCandidate.Extensions() { + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = srflxCandidate.ID() + } + if err := candidate.AddExtension(e); err != nil { + return nil, fmt.Errorf("add extension: %w", err) + } + } + + return candidate, nil } func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { @@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) { if !lok || !rok { continue } - w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", + w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms", sessionID, - local.NetworkType(), local.Type(), local.Address(), - remote.NetworkType(), remote.Type(), remote.Address(), + local.NetworkType(), local.Type(), local.Address(), local.Port(), + remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(), stat.CurrentRoundTripTime*1000) } } diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 06309fbaf..0402992c9 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "sync" "sync/atomic" @@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.relaySupportedOnRemotePeer.Store(true) // the relayManager will return with error in case if the connection has lost with relay server - currentRelayAddress, err := w.relayManager.RelayInstanceAddress() + currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress() if err != nil { w.log.Errorf("failed to handle new offer: %s", err) return } srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) + var serverIP netip.Addr + if srv == remoteOfferAnswer.RelaySrvAddress { + serverIP = remoteOfferAnswer.RelaySrvIP + } - relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key) + relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { w.log.Debugf("handled offer by reusing existing relay connection") @@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { }) } -func (w *WorkerRelay) RelayInstanceAddress() (string, error) { +func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) { return w.relayManager.RelayInstanceAddress() } diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go new file mode 100644 index 000000000..ba83c79bf --- /dev/null +++ b/client/internal/portforward/env.go @@ -0,0 +1,35 @@ +package portforward + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" + envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK" +) + +func isDisabledByEnv() bool { + return parseBoolEnv(envDisableNATMapper) +} + +func isHealthCheckDisabled() bool { + return parseBoolEnv(envDisablePCPHealthCheck) +} + +func parseBoolEnv(key string) bool { + val := os.Getenv(key) + if val == "" { + return false + } + + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", key, err) + return false + } + return disabled +} diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go new file mode 100644 index 000000000..b0680160c --- /dev/null +++ b/client/internal/portforward/manager.go @@ -0,0 +1,342 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + "net" + "regexp" + "sync" + "time" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" +) + +const ( + defaultMappingTTL = 2 * time.Hour + healthCheckInterval = 1 * time.Minute + discoveryTimeout = 10 * time.Second + mappingDescription = "NetBird" +) + +// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML, +// allowing for whitespace/newlines between tags from different router firmware. +var upnpErrPermanentLeaseOnly = regexp.MustCompile(`\s*725\s*`) + +// Mapping represents an active NAT port mapping. +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string + // TTL is the lease duration. Zero means a permanent lease that never expires. + TTL time.Duration +} + +// TODO: persist mapping state for crash recovery cleanup of permanent leases. +// Currently not done because State.Cleanup requires NAT gateway re-discovery, +// which blocks startup for ~10s when no gateway is present (affects all clients). + +type Manager struct { + cancel context.CancelFunc + + mapping *Mapping + mappingLock sync.Mutex + + wgPort uint16 + + done chan struct{} + stopCtx chan context.Context + + // protect exported functions + mu sync.Mutex +} + +// NewManager creates a new port forwarding manager. +func NewManager() *Manager { + return &Manager{ + stopCtx: make(chan context.Context, 1), + } +} + +func (m *Manager) Start(ctx context.Context, wgPort uint16) { + m.mu.Lock() + if m.cancel != nil { + m.mu.Unlock() + return + } + + if isDisabledByEnv() { + log.Infof("NAT port mapper disabled via %s", envDisableNATMapper) + m.mu.Unlock() + return + } + + if wgPort == 0 { + log.Warnf("invalid WireGuard port 0; NAT mapping disabled") + m.mu.Unlock() + return + } + m.wgPort = wgPort + + m.done = make(chan struct{}) + defer close(m.done) + + ctx, m.cancel = context.WithCancel(ctx) + m.mu.Unlock() + + gateway, mapping, err := m.setup(ctx) + if err != nil { + log.Infof("port forwarding setup: %v", err) + return + } + + m.mappingLock.Lock() + m.mapping = mapping + m.mappingLock.Unlock() + + m.renewLoop(ctx, gateway, mapping.TTL) + + select { + case cleanupCtx := <-m.stopCtx: + // block the Start while cleaned up gracefully + m.cleanup(cleanupCtx, gateway) + default: + // return Start immediately and cleanup in background + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second) + go func() { + defer cleanupCancel() + m.cleanup(cleanupCtx, gateway) + }() + } +} + +// GetMapping returns the current mapping if ready, nil otherwise +func (m *Manager) GetMapping() *Mapping { + m.mappingLock.Lock() + defer m.mappingLock.Unlock() + + if m.mapping == nil { + return nil + } + + mapping := *m.mapping + return &mapping +} + +// GracefullyStop cancels the manager and attempts to delete the port mapping. +// After GracefullyStop returns, the manager cannot be restarted. +func (m *Manager) GracefullyStop(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel == nil { + return nil + } + + // Send cleanup context before cancelling, so Start picks it up after renewLoop exits. + m.startTearDown(ctx) + + m.cancel() + m.cancel = nil + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } +} + +func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { + discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout) + defer discoverCancel() + + gateway, err := discoverGateway(discoverCtx) + if err != nil { + return nil, nil, fmt.Errorf("discover gateway: %w", err) + } + + log.Infof("discovered NAT gateway: %s", gateway.Type()) + + mapping, err := m.createMapping(ctx, gateway) + if err != nil { + return nil, nil, fmt.Errorf("create port mapping: %w", err) + } + return gateway, mapping, nil +} + +func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ttl := defaultMappingTTL + externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl) + if err != nil { + if !isPermanentLeaseRequired(err) { + return nil, err + } + log.Infof("gateway only supports permanent leases, retrying with indefinite duration") + ttl = 0 + externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl) + if err != nil { + return nil, err + } + } + + externalIP, err := gateway.GetExternalAddress() + if err != nil { + log.Debugf("failed to get external address: %v", err) + } + + mapping := &Mapping{ + Protocol: "udp", + InternalPort: m.wgPort, + ExternalPort: uint16(externalPort), + ExternalIP: externalIP, + NATType: gateway.Type(), + TTL: ttl, + } + + log.Infof("created port mapping: %d -> %d via %s (external IP: %s)", + m.wgPort, externalPort, gateway.Type(), externalIP) + return mapping, nil +} + +func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) { + if ttl == 0 { + // Permanent mappings don't expire, just wait for cancellation + // but still run health checks for PCP gateways. + m.permanentLeaseLoop(ctx, gateway) + return + } + + renewTicker := time.NewTicker(ttl / 2) + healthTicker := time.NewTicker(healthCheckInterval) + defer renewTicker.Stop() + defer healthTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-renewTicker.C: + if err := m.renewMapping(ctx, gateway); err != nil { + log.Warnf("failed to renew port mapping: %v", err) + continue + } + case <-healthTicker.C: + if m.checkHealthAndRecreate(ctx, gateway) { + renewTicker.Reset(ttl / 2) + } + } + } +} + +func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) { + healthTicker := time.NewTicker(healthCheckInterval) + defer healthTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-healthTicker.C: + m.checkHealthAndRecreate(ctx, gateway) + } + } +} + +func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool { + if isHealthCheckDisabled() { + return false + } + + m.mappingLock.Lock() + hasMapping := m.mapping != nil + m.mappingLock.Unlock() + + if !hasMapping { + return false + } + + pcpNAT, ok := gateway.(*pcp.NAT) + if !ok { + return false + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx) + if err != nil { + log.Debugf("PCP health check failed: %v", err) + return false + } + + if serverRestarted { + log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch) + if err := m.renewMapping(ctx, gateway); err != nil { + log.Errorf("failed to recreate port mapping after server restart: %v", err) + return false + } + return true + } + + return false +} + +func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL) + if err != nil { + return fmt.Errorf("add port mapping: %w", err) + } + + if uint16(externalPort) != m.mapping.ExternalPort { + log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort) + m.mappingLock.Lock() + m.mapping.ExternalPort = uint16(externalPort) + m.mappingLock.Unlock() + } + + log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort) + return nil +} + +func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) { + m.mappingLock.Lock() + mapping := m.mapping + m.mapping = nil + m.mappingLock.Unlock() + + if mapping == nil { + return + } + + if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil { + log.Warnf("delete port mapping on stop: %v", err) + return + } + + log.Infof("deleted port mapping for port %d", mapping.InternalPort) +} + +func (m *Manager) startTearDown(ctx context.Context) { + select { + case m.stopCtx <- ctx: + default: + } +} + +// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725). +func isPermanentLeaseRequired(err error) bool { + return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error()) +} diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go new file mode 100644 index 000000000..36c55063b --- /dev/null +++ b/client/internal/portforward/manager_js.go @@ -0,0 +1,39 @@ +package portforward + +import ( + "context" + "net" + "time" +) + +// Mapping represents an active NAT port mapping. +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string + // TTL is the lease duration. Zero means a permanent lease that never expires. + TTL time.Duration +} + +// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported. +type Manager struct{} + +// NewManager returns a stub manager for js/wasm builds. +func NewManager() *Manager { + return &Manager{} +} + +// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments. +func (m *Manager) Start(context.Context, uint16) { + // no NAT traversal in wasm +} + +// GracefullyStop is a no-op on js/wasm. +func (m *Manager) GracefullyStop(context.Context) error { return nil } + +// GetMapping always returns nil on js/wasm. +func (m *Manager) GetMapping() *Mapping { + return nil +} diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go new file mode 100644 index 000000000..1f66f9ccd --- /dev/null +++ b/client/internal/portforward/manager_test.go @@ -0,0 +1,201 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockNAT struct { + natType string + deviceAddr net.IP + externalAddr net.IP + internalAddr net.IP + mappings map[int]int + addMappingErr error + deleteMappingErr error + onlyPermanentLeases bool + lastTimeout time.Duration +} + +func newMockNAT() *mockNAT { + return &mockNAT{ + natType: "Mock-NAT", + deviceAddr: net.ParseIP("192.168.1.1"), + externalAddr: net.ParseIP("203.0.113.50"), + internalAddr: net.ParseIP("192.168.1.100"), + mappings: make(map[int]int), + } +} + +func (m *mockNAT) Type() string { + return m.natType +} + +func (m *mockNAT) GetDeviceAddress() (net.IP, error) { + return m.deviceAddr, nil +} + +func (m *mockNAT) GetExternalAddress() (net.IP, error) { + return m.externalAddr, nil +} + +func (m *mockNAT) GetInternalAddress() (net.IP, error) { + return m.internalAddr, nil +} + +func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) { + if m.addMappingErr != nil { + return 0, m.addMappingErr + } + if m.onlyPermanentLeases && timeout != 0 { + return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: 725OnlyPermanentLeasesSupported") + } + externalPort := internalPort + m.mappings[internalPort] = externalPort + m.lastTimeout = timeout + return externalPort, nil +} + +func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if m.deleteMappingErr != nil { + return m.deleteMappingErr + } + delete(m.mappings, internalPort) + return nil +} + +func TestManager_CreateMapping(t *testing.T) { + m := NewManager() + m.wgPort = 51820 + + gateway := newMockNAT() + mapping, err := m.createMapping(context.Background(), gateway) + require.NoError(t, err) + require.NotNil(t, mapping) + + assert.Equal(t, "udp", mapping.Protocol) + assert.Equal(t, uint16(51820), mapping.InternalPort) + assert.Equal(t, uint16(51820), mapping.ExternalPort) + assert.Equal(t, "Mock-NAT", mapping.NATType) + assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4()) + assert.Equal(t, defaultMappingTTL, mapping.TTL) +} + +func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) { + m := NewManager() + assert.Nil(t, m.GetMapping()) +} + +func TestManager_GetMapping_ReturnsCopy(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } + + mapping := m.GetMapping() + require.NotNil(t, mapping) + assert.Equal(t, uint16(51820), mapping.InternalPort) + + // Mutating the returned copy should not affect the manager's mapping. + mapping.ExternalPort = 9999 + assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort) +} + +func TestManager_Cleanup_DeletesMapping(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } + + gateway := newMockNAT() + // Seed the mock so we can verify deletion. + gateway.mappings[51820] = 51820 + + m.cleanup(context.Background(), gateway) + + _, exists := gateway.mappings[51820] + assert.False(t, exists, "mapping should be deleted from gateway") + assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared") +} + +func TestManager_Cleanup_NilMapping(t *testing.T) { + m := NewManager() + gateway := newMockNAT() + + // Should not panic or call gateway. + m.cleanup(context.Background(), gateway) +} + + +func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) { + m := NewManager() + m.wgPort = 51820 + + gateway := newMockNAT() + gateway.onlyPermanentLeases = true + + mapping, err := m.createMapping(context.Background(), gateway) + require.NoError(t, err) + require.NotNil(t, mapping) + + assert.Equal(t, uint16(51820), mapping.InternalPort) + assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease") + assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration") +} + +func TestIsPermanentLeaseRequired(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "UPnP error 725", + err: fmt.Errorf("SOAP fault. Code: | Detail: 725OnlyPermanentLeasesSupported"), + expected: true, + }, + { + name: "wrapped error with 725", + err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: 725")), + expected: true, + }, + { + name: "error 725 with newlines in XML", + err: fmt.Errorf("\n 725\n"), + expected: true, + }, + { + name: "bare 725 without XML tag", + err: fmt.Errorf("error code 725"), + expected: false, + }, + { + name: "unrelated error", + err: fmt.Errorf("connection refused"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err)) + }) + } +} diff --git a/client/internal/portforward/pcp/client.go b/client/internal/portforward/pcp/client.go new file mode 100644 index 000000000..f6d243ef9 --- /dev/null +++ b/client/internal/portforward/pcp/client.go @@ -0,0 +1,408 @@ +package pcp + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultTimeout = 3 * time.Second + responseBufferSize = 128 + + // RFC 6887 Section 8.1.1 retry timing + initialRetryDelay = 3 * time.Second + maxRetryDelay = 1024 * time.Second + maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case +) + +// Client is a PCP protocol client. +// All methods are safe for concurrent use. +type Client struct { + gateway netip.Addr + timeout time.Duration + + mu sync.Mutex + // localIP caches the resolved local IP address. + localIP netip.Addr + // lastEpoch is the last observed server epoch value. + lastEpoch uint32 + // epochTime tracks when lastEpoch was received for state loss detection. + epochTime time.Time + // externalIP caches the external IP from the last successful MAP response. + externalIP netip.Addr + // epochStateLost is set when epoch indicates server restart. + epochStateLost bool +} + +// NewClient creates a new PCP client for the gateway at the given IP. +func NewClient(gateway net.IP) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: defaultTimeout, + } +} + +// NewClientWithTimeout creates a new PCP client with a custom timeout. +func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: timeout, + } +} + +// SetLocalIP sets the local IP address to use in PCP requests. +func (c *Client) SetLocalIP(ip net.IP) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + log.Debugf("invalid local IP: %v", ip) + } + c.mu.Lock() + c.localIP = addr.Unmap() + c.mu.Unlock() +} + +// Gateway returns the gateway IP address. +func (c *Client) Gateway() net.IP { + return c.gateway.AsSlice() +} + +// Announce sends a PCP ANNOUNCE request to discover PCP support. +// Returns the server's epoch time on success. +func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) { + localIP, err := c.getLocalIP() + if err != nil { + return 0, fmt.Errorf("get local IP: %w", err) + } + + req := buildAnnounceRequest(localIP) + resp, err := c.sendRequest(ctx, req) + if err != nil { + return 0, fmt.Errorf("send announce: %w", err) + } + + parsed, err := parseResponse(resp) + if err != nil { + return 0, fmt.Errorf("parse announce response: %w", err) + } + + if parsed.ResultCode != ResultSuccess { + return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode)) + } + + c.mu.Lock() + if c.updateEpochLocked(parsed.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.mu.Unlock() + return parsed.Epoch, nil +} + +// AddPortMapping requests a port mapping from the PCP server. +func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) { + return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime) +} + +// AddPortMappingWithHint requests a port mapping with suggested external port and IP. +// Use lifetime <= 0 to delete a mapping. +func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) { + var extIP netip.Addr + if suggestedExtIP != nil { + var ok bool + extIP, ok = netip.AddrFromSlice(suggestedExtIP) + if !ok { + log.Debugf("invalid suggested external IP: %v", suggestedExtIP) + } + extIP = extIP.Unmap() + } + return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime) +} + +func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) { + localIP, err := c.getLocalIP() + if err != nil { + return nil, fmt.Errorf("get local IP: %w", err) + } + + proto, err := protocolNumber(protocol) + if err != nil { + return nil, fmt.Errorf("parse protocol: %w", err) + } + + var nonce [12]byte + if _, err := rand.Read(nonce[:]); err != nil { + return nil, fmt.Errorf("generate nonce: %w", err) + } + + // Convert lifetime to seconds. Lifetime 0 means delete, so only apply + // default for positive durations that round to 0 seconds. + var lifetimeSec uint32 + if lifetime > 0 { + lifetimeSec = uint32(lifetime.Seconds()) + if lifetimeSec == 0 { + lifetimeSec = DefaultLifetime + } + } + + req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec) + + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("send map request: %w", err) + } + + mapResp, err := parseMapResponse(resp) + if err != nil { + return nil, fmt.Errorf("parse map response: %w", err) + } + + if mapResp.Nonce != nonce { + return nil, fmt.Errorf("nonce mismatch in response") + } + + if mapResp.Protocol != proto { + return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol) + } + if mapResp.InternalPort != uint16(internalPort) { + return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort) + } + + if mapResp.ResultCode != ResultSuccess { + return nil, &Error{ + Code: mapResp.ResultCode, + Message: ResultCodeString(mapResp.ResultCode), + } + } + + c.mu.Lock() + if c.updateEpochLocked(mapResp.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.cacheExternalIPLocked(mapResp.ExternalIP) + c.mu.Unlock() + return mapResp, nil +} + +// DeletePortMapping removes a port mapping by requesting zero lifetime. +func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil { + var pcpErr *Error + if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized { + return nil + } + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// GetExternalAddress returns the external IP address. +// First checks for a cached value from previous MAP responses. +// If not cached, creates a short-lived mapping to discover the external IP. +func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) { + c.mu.Lock() + if c.externalIP.IsValid() { + ip := c.externalIP.AsSlice() + c.mu.Unlock() + return ip, nil + } + c.mu.Unlock() + + // Use an ephemeral port in the dynamic range (49152-65535). + // Port 0 is not valid with UDP/TCP protocols per RFC 6887. + ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152) + + // Use minimal lifetime (1 second) for discovery. + resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second) + if err != nil { + return nil, fmt.Errorf("create temporary mapping: %w", err) + } + + if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil { + log.Debugf("cleanup temporary PCP mapping: %v", err) + } + + return resp.ExternalIP.AsSlice(), nil +} + +// LastEpoch returns the last observed server epoch value. +// A decrease in epoch indicates the server may have restarted and mappings may be lost. +func (c *Client) LastEpoch() uint32 { + c.mu.Lock() + defer c.mu.Unlock() + return c.lastEpoch +} + +// EpochStateLost returns true if epoch state loss was detected and clears the flag. +func (c *Client) EpochStateLost() bool { + c.mu.Lock() + defer c.mu.Unlock() + lost := c.epochStateLost + c.epochStateLost = false + return lost +} + +// updateEpoch updates the epoch tracking and detects potential state loss. +// Returns true if state loss was detected (server likely restarted). +// Caller must hold c.mu. +func (c *Client) updateEpochLocked(newEpoch uint32) bool { + now := time.Now() + stateLost := false + + // RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss. + // client_delta = time since last response + // server_delta = epoch change since last response + // Invalid if: client_delta+2 < server_delta - server_delta/16 + // OR: server_delta+2 < client_delta - client_delta/16 + // The +2 handles quantization, /16 (6.25%) handles clock drift. + if !c.epochTime.IsZero() && c.lastEpoch > 0 { + clientDelta := uint32(now.Sub(c.epochTime).Seconds()) + serverDelta := newEpoch - c.lastEpoch + + // Check for epoch going backwards or jumping unexpectedly. + // Subtraction is safe: serverDelta/16 is always <= serverDelta. + if clientDelta+2 < serverDelta-(serverDelta/16) || + serverDelta+2 < clientDelta-(clientDelta/16) { + stateLost = true + c.epochStateLost = true + } + } + + c.lastEpoch = newEpoch + c.epochTime = now + return stateLost +} + +// cacheExternalIP stores the external IP from a successful MAP response. +// Caller must hold c.mu. +func (c *Client) cacheExternalIPLocked(ip netip.Addr) { + if ip.IsValid() && !ip.IsUnspecified() { + c.externalIP = ip + } +} + +// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1. +func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) { + addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port} + + var lastErr error + delay := initialRetryDelay + + for range maxRetries { + resp, err := c.sendOnce(ctx, addr, req) + if err == nil { + return resp, nil + } + lastErr = err + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT) + // RAND is random between -0.1 and +0.1 + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelayWithJitter(delay)): + } + delay = min(delay*2, maxRetryDelay) + } + + return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr) +} + +// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1]. +func retryDelayWithJitter(d time.Duration) time.Duration { + var b [1]byte + _, _ = rand.Read(b[:]) + // Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1 + jitter := (float64(b[0])/255.0)*0.2 - 0.1 + return time.Duration(float64(d) * (1 + jitter)) +} + +func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) { + // Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3. + conn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, fmt.Errorf("listen: %w", err) + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("close UDP connection: %v", err) + } + }() + + timeout := c.timeout + if deadline, ok := ctx.Deadline(); ok { + if remaining := time.Until(deadline); remaining < timeout { + timeout = remaining + } + } + + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, fmt.Errorf("set deadline: %w", err) + } + + if _, err := conn.WriteToUDP(req, addr); err != nil { + return nil, fmt.Errorf("write: %w", err) + } + + resp := make([]byte, responseBufferSize) + n, from, err := conn.ReadFromUDP(resp) + if err != nil { + return nil, fmt.Errorf("read: %w", err) + } + + // RFC 6887 §8.3: Validate response came from expected PCP server. + if !from.IP.Equal(addr.IP) { + return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP) + } + + return resp[:n], nil +} + +func (c *Client) getLocalIP() (netip.Addr, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.localIP.IsValid() { + return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway) + } + return c.localIP, nil +} + +func protocolNumber(protocol string) (uint8, error) { + switch protocol { + case "udp", "UDP": + return ProtoUDP, nil + case "tcp", "TCP": + return ProtoTCP, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} + +// Error represents a PCP error response. +type Error struct { + Code uint8 + Message string +} + +func (e *Error) Error() string { + return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code) +} diff --git a/client/internal/portforward/pcp/client_test.go b/client/internal/portforward/pcp/client_test.go new file mode 100644 index 000000000..79f44a426 --- /dev/null +++ b/client/internal/portforward/pcp/client_test.go @@ -0,0 +1,187 @@ +package pcp + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddrConversion(t *testing.T) { + tests := []struct { + name string + addr netip.Addr + }{ + {"IPv4", netip.MustParseAddr("192.168.1.100")}, + {"IPv4 loopback", netip.MustParseAddr("127.0.0.1")}, + {"IPv6", netip.MustParseAddr("2001:db8::1")}, + {"IPv6 loopback", netip.MustParseAddr("::1")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b16 := addrTo16(tt.addr) + + recovered := addrFrom16(b16) + assert.Equal(t, tt.addr, recovered, "address should round-trip") + }) + } +} + +func TestBuildAnnounceRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + req := buildAnnounceRequest(clientIP) + + require.Len(t, req, headerSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpAnnounce), req[1], "opcode") + + // Check client IP is properly encoded as IPv4-mapped IPv6 + assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10") + assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11") + assert.Equal(t, byte(192), req[20], "IP octet 1") + assert.Equal(t, byte(168), req[21], "IP octet 2") + assert.Equal(t, byte(1), req[22], "IP octet 3") + assert.Equal(t, byte(100), req[23], "IP octet 4") +} + +func TestBuildMapRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600) + + require.Len(t, req, mapRequestSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpMap), req[1], "opcode") + + // Lifetime at bytes 4-7 + assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime") + + // Nonce at bytes 24-35 + assert.Equal(t, nonce[:], req[24:36], "nonce") + + // Protocol at byte 36 + assert.Equal(t, byte(ProtoUDP), req[36], "protocol") + + // Internal port at bytes 40-41 + assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port") + + // External port at bytes 42-43 + assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port") +} + +func TestParseResponse(t *testing.T) { + // Construct a valid ANNOUNCE response + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce | OpReply + // Result code = 0 (success) + // Lifetime = 0 + // Epoch = 12345 + resp[8] = 0 + resp[9] = 0 + resp[10] = 0x30 + resp[11] = 0x39 + + parsed, err := parseResponse(resp) + require.NoError(t, err) + assert.Equal(t, uint8(Version), parsed.Version) + assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode) + assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode) + assert.Equal(t, uint32(12345), parsed.Epoch) +} + +func TestParseResponseErrors(t *testing.T) { + t.Run("too short", func(t *testing.T) { + _, err := parseResponse([]byte{1, 2, 3}) + assert.Error(t, err) + }) + + t.Run("wrong version", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = 1 // Wrong version + resp[1] = OpReply + _, err := parseResponse(resp) + assert.Error(t, err) + }) + + t.Run("missing reply bit", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce // Missing OpReply bit + _, err := parseResponse(resp) + assert.Error(t, err) + }) +} + +func TestResultCodeString(t *testing.T) { + assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess)) + assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized)) + assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch)) + assert.Contains(t, ResultCodeString(255), "UNKNOWN") +} + +func TestProtocolNumber(t *testing.T) { + proto, err := protocolNumber("udp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + proto, err = protocolNumber("tcp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoTCP), proto) + + proto, err = protocolNumber("UDP") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + _, err = protocolNumber("icmp") + assert.Error(t, err) +} + +func TestClientCreation(t *testing.T) { + gateway := netip.MustParseAddr("192.168.1.1").AsSlice() + + client := NewClient(gateway) + assert.Equal(t, net.IP(gateway), client.Gateway()) + assert.Equal(t, defaultTimeout, client.timeout) + + clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second) + assert.Equal(t, 5*time.Second, clientWithTimeout.timeout) +} + +func TestNATType(t *testing.T) { + n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice()) + assert.Equal(t, "PCP", n.Type()) +} + +// Integration test - skipped unless PCP_TEST_GATEWAY env is set +func TestClientIntegration(t *testing.T) { + t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=") + + gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway + localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP + + client := NewClient(gateway) + client.SetLocalIP(localIP) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Test ANNOUNCE + epoch, err := client.Announce(ctx) + require.NoError(t, err) + t.Logf("Server epoch: %d", epoch) + + // Test MAP + resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour) + require.NoError(t, err) + t.Logf("Mapping: internal=%d external=%d externalIP=%s", + resp.InternalPort, resp.ExternalPort, resp.ExternalIP) + + // Cleanup + err = client.DeletePortMapping(ctx, "udp", 51820) + require.NoError(t, err) +} diff --git a/client/internal/portforward/pcp/nat.go b/client/internal/portforward/pcp/nat.go new file mode 100644 index 000000000..1dc24274b --- /dev/null +++ b/client/internal/portforward/pcp/nat.go @@ -0,0 +1,209 @@ +package pcp + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/libp2p/go-nat" + "github.com/libp2p/go-netroute" +) + +var _ nat.NAT = (*NAT)(nil) + +// NAT implements the go-nat NAT interface using PCP. +// Supports dual-stack (IPv4 and IPv6) when available. +// All methods are safe for concurrent use. +// +// TODO: IPv6 pinholes use the local IPv6 address. If the address changes +// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale +// and needs to be recreated with the new address. +type NAT struct { + client *Client + + mu sync.RWMutex + // client6 is the IPv6 PCP client, nil if IPv6 is unavailable. + client6 *Client + // localIP6 caches the local IPv6 address used for PCP requests. + localIP6 netip.Addr +} + +// NewNAT creates a new NAT instance backed by PCP. +func NewNAT(gateway, localIP net.IP) *NAT { + client := NewClient(gateway) + client.SetLocalIP(localIP) + return &NAT{ + client: client, + } +} + +// Type returns "PCP" as the NAT type. +func (n *NAT) Type() string { + return "PCP" +} + +// GetDeviceAddress returns the gateway IP address. +func (n *NAT) GetDeviceAddress() (net.IP, error) { + return n.client.Gateway(), nil +} + +// GetExternalAddress returns the external IP address. +func (n *NAT) GetExternalAddress() (net.IP, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return n.client.GetExternalAddress(ctx) +} + +// GetInternalAddress returns the local IP address used to communicate with the gateway. +func (n *NAT) GetInternalAddress() (net.IP, error) { + addr, err := n.client.getLocalIP() + if err != nil { + return nil, err + } + return addr.AsSlice(), nil +} + +// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available). +func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) { + resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout) + if err != nil { + return 0, fmt.Errorf("add mapping: %w", err) + } + + n.mu.RLock() + client6 := n.client6 + localIP6 := n.localIP6 + n.mu.RUnlock() + + if client6 == nil { + return int(resp.ExternalPort), nil + } + + if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil { + log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err) + return int(resp.ExternalPort), nil + } + + log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort) + return int(resp.ExternalPort), nil +} + +// DeletePortMapping removes a port mapping from both IPv4 and IPv6. +func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + err := n.client.DeletePortMapping(ctx, protocol, internalPort) + + n.mu.RLock() + client6 := n.client6 + n.mu.RUnlock() + + if client6 != nil { + if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil { + log.Warnf("IPv6 PCP delete mapping failed: %v", err6) + } + } + + if err != nil { + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive. +// Returns the current epoch and whether the server may have restarted (epoch state loss detected). +func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) { + epoch, err = n.client.Announce(ctx) + if err != nil { + return 0, false, fmt.Errorf("announce: %w", err) + } + return epoch, n.client.EpochStateLost(), nil +} + +// DiscoverPCP attempts to discover a PCP-capable gateway. +// Returns a NAT interface if PCP is supported, or an error otherwise. +// Discovers both IPv4 and IPv6 gateways when available. +func DiscoverPCP(ctx context.Context) (nat.NAT, error) { + gateway, localIP, err := getDefaultGateway() + if err != nil { + return nil, fmt.Errorf("get default gateway: %w", err) + } + + client := NewClient(gateway) + client.SetLocalIP(localIP) + if _, err := client.Announce(ctx); err != nil { + return nil, fmt.Errorf("PCP announce: %w", err) + } + + result := &NAT{client: client} + discoverIPv6(ctx, result) + + return result, nil +} + +func discoverIPv6(ctx context.Context, result *NAT) { + gateway6, localIP6, err := getDefaultGateway6() + if err != nil { + log.Debugf("IPv6 gateway discovery failed: %v", err) + return + } + + client6 := NewClient(gateway6) + client6.SetLocalIP(localIP6) + if _, err := client6.Announce(ctx); err != nil { + log.Debugf("PCP IPv6 announce failed: %v", err) + return + } + + addr, ok := netip.AddrFromSlice(localIP6) + if !ok { + log.Debugf("invalid IPv6 local IP: %v", localIP6) + return + } + result.mu.Lock() + result.client6 = client6 + result.localIP6 = addr + result.mu.Unlock() + log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6) +} + +// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table. +func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv4zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} + +// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table. +func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv6zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} diff --git a/client/internal/portforward/pcp/protocol.go b/client/internal/portforward/pcp/protocol.go new file mode 100644 index 000000000..d81c50c8c --- /dev/null +++ b/client/internal/portforward/pcp/protocol.go @@ -0,0 +1,225 @@ +// Package pcp implements the Port Control Protocol (RFC 6887). +// +// # Implemented Features +// +// - ANNOUNCE opcode: Discovers PCP server support +// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6) +// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients +// - Nonce validation: Prevents response spoofing +// - Epoch tracking: Detects server restarts per Section 8.5 +// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1) +// +// # Not Implemented +// +// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal) +// - THIRD_PARTY option: For managing mappings on behalf of other devices +// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing) +// - FILTER option: To restrict remote peer addresses +// +// These optional features are omitted because the primary use case is simple +// port forwarding for WireGuard, which only requires MAP with default behavior. +package pcp + +import ( + "encoding/binary" + "fmt" + "net/netip" +) + +const ( + // Version is the PCP protocol version (RFC 6887). + Version = 2 + + // Port is the standard PCP server port. + Port = 5351 + + // DefaultLifetime is the default requested mapping lifetime in seconds. + DefaultLifetime = 7200 // 2 hours + + // Header sizes + headerSize = 24 + mapPayloadSize = 36 + mapRequestSize = headerSize + mapPayloadSize // 60 bytes +) + +// Opcodes +const ( + OpAnnounce = 0 + OpMap = 1 + OpPeer = 2 + OpReply = 0x80 // OR'd with opcode in responses +) + +// Protocol numbers for MAP requests +const ( + ProtoUDP = 17 + ProtoTCP = 6 +) + +// Result codes (RFC 6887 Section 7.4) +const ( + ResultSuccess = 0 + ResultUnsuppVersion = 1 + ResultNotAuthorized = 2 + ResultMalformedRequest = 3 + ResultUnsuppOpcode = 4 + ResultUnsuppOption = 5 + ResultMalformedOption = 6 + ResultNetworkFailure = 7 + ResultNoResources = 8 + ResultUnsuppProtocol = 9 + ResultUserExQuota = 10 + ResultCannotProvideExt = 11 + ResultAddressMismatch = 12 + ResultExcessiveRemotePeers = 13 +) + +// ResultCodeString returns a human-readable string for a result code. +func ResultCodeString(code uint8) string { + switch code { + case ResultSuccess: + return "SUCCESS" + case ResultUnsuppVersion: + return "UNSUPP_VERSION" + case ResultNotAuthorized: + return "NOT_AUTHORIZED" + case ResultMalformedRequest: + return "MALFORMED_REQUEST" + case ResultUnsuppOpcode: + return "UNSUPP_OPCODE" + case ResultUnsuppOption: + return "UNSUPP_OPTION" + case ResultMalformedOption: + return "MALFORMED_OPTION" + case ResultNetworkFailure: + return "NETWORK_FAILURE" + case ResultNoResources: + return "NO_RESOURCES" + case ResultUnsuppProtocol: + return "UNSUPP_PROTOCOL" + case ResultUserExQuota: + return "USER_EX_QUOTA" + case ResultCannotProvideExt: + return "CANNOT_PROVIDE_EXTERNAL" + case ResultAddressMismatch: + return "ADDRESS_MISMATCH" + case ResultExcessiveRemotePeers: + return "EXCESSIVE_REMOTE_PEERS" + default: + return fmt.Sprintf("UNKNOWN(%d)", code) + } +} + +// Response represents a parsed PCP response header. +type Response struct { + Version uint8 + Opcode uint8 + ResultCode uint8 + Lifetime uint32 + Epoch uint32 +} + +// MapResponse contains the full response to a MAP request. +type MapResponse struct { + Response + Nonce [12]byte + Protocol uint8 + InternalPort uint16 + ExternalPort uint16 + ExternalIP netip.Addr +} + +// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation. +func addrTo16(addr netip.Addr) [16]byte { + if addr.Is4() { + return netip.AddrFrom4(addr.As4()).As16() + } + return addr.As16() +} + +// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4. +func addrFrom16(b [16]byte) netip.Addr { + return netip.AddrFrom16(b).Unmap() +} + +// buildAnnounceRequest creates a PCP ANNOUNCE request packet. +func buildAnnounceRequest(clientIP netip.Addr) []byte { + req := make([]byte, headerSize) + req[0] = Version + req[1] = OpAnnounce + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + return req +} + +// buildMapRequest creates a PCP MAP request packet. +func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte { + req := make([]byte, mapRequestSize) + + // Header + req[0] = Version + req[1] = OpMap + binary.BigEndian.PutUint32(req[4:8], lifetime) + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + + // MAP payload + copy(req[24:36], nonce[:]) + req[36] = protocol + binary.BigEndian.PutUint16(req[40:42], internalPort) + binary.BigEndian.PutUint16(req[42:44], suggestedExtPort) + if suggestedExtIP.IsValid() { + extMapped := addrTo16(suggestedExtIP) + copy(req[44:60], extMapped[:]) + } + + return req +} + +// parseResponse parses the common PCP response header. +func parseResponse(data []byte) (*Response, error) { + if len(data) < headerSize { + return nil, fmt.Errorf("response too short: %d bytes", len(data)) + } + + resp := &Response{ + Version: data[0], + Opcode: data[1], + ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2) + Lifetime: binary.BigEndian.Uint32(data[4:8]), + Epoch: binary.BigEndian.Uint32(data[8:12]), + } + + if resp.Version != Version { + return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version) + } + + if resp.Opcode&OpReply == 0 { + return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode) + } + + return resp, nil +} + +// parseMapResponse parses a complete MAP response. +func parseMapResponse(data []byte) (*MapResponse, error) { + if len(data) < mapRequestSize { + return nil, fmt.Errorf("MAP response too short: %d bytes", len(data)) + } + + resp, err := parseResponse(data) + if err != nil { + return nil, fmt.Errorf("parse header: %w", err) + } + + mapResp := &MapResponse{ + Response: *resp, + Protocol: data[36], + InternalPort: binary.BigEndian.Uint16(data[40:42]), + ExternalPort: binary.BigEndian.Uint16(data[42:44]), + ExternalIP: addrFrom16([16]byte(data[44:60])), + } + copy(mapResp.Nonce[:], data[24:36]) + + return mapResp, nil +} diff --git a/client/internal/portforward/state.go b/client/internal/portforward/state.go new file mode 100644 index 000000000..b1315cdc0 --- /dev/null +++ b/client/internal/portforward/state.go @@ -0,0 +1,63 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" +) + +// discoverGateway is the function used for NAT gateway discovery. +// It can be replaced in tests to avoid real network operations. +// Tries PCP first, then falls back to NAT-PMP/UPnP. +var discoverGateway = defaultDiscoverGateway + +func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) { + pcpGateway, err := pcp.DiscoverPCP(ctx) + if err == nil { + return pcpGateway, nil + } + log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err) + + return nat.DiscoverGateway(ctx) +} + +// State is persisted only for crash recovery cleanup +type State struct { + InternalPort uint16 `json:"internal_port,omitempty"` + Protocol string `json:"protocol,omitempty"` +} + +func (s *State) Name() string { + return "port_forward_state" +} + +// Cleanup implements statemanager.CleanableState for crash recovery +func (s *State) Cleanup() error { + if s.InternalPort == 0 { + return nil + } + + log.Infof("cleaning up stale port mapping for port %d", s.InternalPort) + + ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout) + defer cancel() + + gateway, err := discoverGateway(ctx) + if err != nil { + // Discovery failure is not an error - gateway may not exist + log.Debugf("cleanup: no gateway found: %v", err) + return nil + } + + if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil { + return fmt.Errorf("delete port mapping: %w", err) + } + + return nil +} diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index b27f1932f..20c615d57 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -39,6 +39,18 @@ const ( DefaultAdminURL = "https://app.netbird.io:443" ) +// mgmProber is the subset of management client needed for URL migration probes. +type mgmProber interface { + HealthCheck() error + Close() error +} + +// newMgmProber creates a management client for probing URL reachability. +// Overridden in tests to avoid real network calls. +var newMgmProber = func(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (mgmProber, error) { + return mgm.NewClient(ctx, addr, key, tlsEnabled) +} + var DefaultInterfaceBlacklist = []string{ iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", "Tailscale", "tailscale", "docker", "veth", "br-", "lo", @@ -753,21 +765,19 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri return config, err } - client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled) + client, err := newMgmProber(ctx, newURL.Host, key, mgmTlsEnabled) if err != nil { log.Infof("couldn't switch to the new Management %s", newURL.String()) return config, err } defer func() { - err = client.Close() - if err != nil { + if err := client.Close(); err != nil { log.Warnf("failed to close the Management service client %v", err) } }() // gRPC check - _, err = client.GetServerPublicKey() - if err != nil { + if err = client.HealthCheck(); err != nil { log.Infof("couldn't switch to the new Management %s", newURL.String()) return nil, err } diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index ab13cf389..5216f2423 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -10,12 +10,21 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) +type mockMgmProber struct{} + +func (m *mockMgmProber) HealthCheck() error { + return nil +} + +func (m *mockMgmProber) Close() error { return nil } + func TestGetConfig(t *testing.T) { // case 1: new default config has to be generated config, err := UpdateOrCreateConfig(ConfigInput{ @@ -234,6 +243,12 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) { } func TestUpdateOldManagementURL(t *testing.T) { + origProber := newMgmProber + newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) { + return &mockMgmProber{}, nil + } + t.Cleanup(func() { newMgmProber = origProber }) + tests := []struct { name string previousManagementURL string @@ -273,18 +288,17 @@ func TestUpdateOldManagementURL(t *testing.T) { ConfigPath: configPath, }) require.NoError(t, err, "failed to create testing config") - previousStats, err := os.Stat(configPath) - require.NoError(t, err, "failed to create testing config stats") + previousContent, err := os.ReadFile(configPath) + require.NoError(t, err, "failed to read initial config") resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath) require.NoError(t, err, "got error when updating old management url") require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String()) - newStats, err := os.Stat(configPath) - require.NoError(t, err, "failed to create testing config stats") - switch tt.fileShouldNotChange { - case true: - require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change") - case false: - require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed") + newContent, err := os.ReadFile(configPath) + require.NoError(t, err, "failed to read updated config") + if tt.fileShouldNotChange { + require.Equal(t, string(previousContent), string(newContent), "file should not change") + } else { + require.NotEqual(t, string(previousContent), string(newContent), "file should have changed") } }) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 9afe2049d..3923e153b 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -52,6 +52,7 @@ type Manager interface { TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap + GetSelectedClientRoutes() route.HAMap GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -167,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { NetworkType: route.IPv4Network, } cr = append(cr, fakeIPRoute) + m.notifier.SetFakeIPRoute(fakeIPRoute) } m.notifier.SetInitialClientRoutes(cr, routesForComparison) @@ -465,6 +467,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap { return maps.Clone(m.clientRoutes) } +// GetSelectedClientRoutes returns only the currently selected/active client routes, +// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking +// if traffic should be routed through the tunnel. +func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap { + m.mux.Lock() + defer m.mux.Unlock() + + return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) +} + // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { m.mux.Lock() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 6b06144b2..66b5e30dd 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -18,6 +18,7 @@ type MockManager struct { TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap + GetSelectedClientRoutesFunc func() route.HAMap GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route StopFunc func(manager *statemanager.Manager) } @@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { return nil } -// GetClientRoutes mock implementation of GetClientRoutes from Manager interface +// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface func (m *MockManager) GetClientRoutes() route.HAMap { if m.GetClientRoutesFunc != nil { return m.GetClientRoutesFunc() @@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap { return nil } +// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface +func (m *MockManager) GetSelectedClientRoutes() route.HAMap { + if m.GetSelectedClientRoutesFunc != nil { + return m.GetSelectedClientRoutesFunc() + } + return nil +} + // GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { if m.GetClientRoutesWithNetIDFunc != nil { diff --git a/client/internal/routemanager/notifier/notifier_android.go b/client/internal/routemanager/notifier/notifier_android.go index dec0af87c..55e0b7421 100644 --- a/client/internal/routemanager/notifier/notifier_android.go +++ b/client/internal/routemanager/notifier/notifier_android.go @@ -16,6 +16,7 @@ import ( type Notifier struct { initialRoutes []*route.Route currentRoutes []*route.Route + fakeIPRoute *route.Route listener listener.NetworkChangeListener listenerMux sync.Mutex @@ -31,26 +32,15 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { n.listener = listener } +// SetInitialClientRoutes stores the initial route sets for TUN configuration. func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) { - // initialRoutes contains fake IP block for interface configuration - filteredInitial := make([]*route.Route, 0) - for _, r := range initialRoutes { - if r.IsDynamic() { - continue - } - filteredInitial = append(filteredInitial, r) - } - n.initialRoutes = filteredInitial + n.initialRoutes = filterStatic(initialRoutes) + n.currentRoutes = filterStatic(routesForComparison) +} - // routesForComparison excludes fake IP block for comparison with new routes - filteredComparison := make([]*route.Route, 0) - for _, r := range routesForComparison { - if r.IsDynamic() { - continue - } - filteredComparison = append(filteredComparison, r) - } - n.currentRoutes = filteredComparison +// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild. +func (n *Notifier) SetFakeIPRoute(r *route.Route) { + n.fakeIPRoute = r } func (n *Notifier) OnNewRoutes(idMap route.HAMap) { @@ -83,13 +73,28 @@ func (n *Notifier) notify() { return } - routeStrings := n.routesToStrings(n.currentRoutes) + allRoutes := slices.Clone(n.currentRoutes) + if n.fakeIPRoute != nil { + allRoutes = append(allRoutes, n.fakeIPRoute) + } + + routeStrings := n.routesToStrings(allRoutes) sort.Strings(routeStrings) go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ",")) + l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ",")) }(n.listener) } +func filterStatic(routes []*route.Route) []*route.Route { + out := make([]*route.Route, 0, len(routes)) + for _, r := range routes { + if !r.IsDynamic() { + out = append(out, r) + } + } + return out +} + func (n *Notifier) routesToStrings(routes []*route.Route) []string { nets := make([]string, 0, len(routes)) for _, r := range routes { diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go index bb125cfa4..68c85067a 100644 --- a/client/internal/routemanager/notifier/notifier_ios.go +++ b/client/internal/routemanager/notifier/notifier_ios.go @@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // iOS doesn't care about initial routes } +func (n *Notifier) SetFakeIPRoute(*route.Route) { + // Not used on iOS +} + func (n *Notifier) OnNewRoutes(route.HAMap) { // Not used on iOS } @@ -53,7 +57,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { n.currentPrefixes = newNets n.notify() } - func (n *Notifier) notify() { n.listenerMux.Lock() defer n.listenerMux.Unlock() diff --git a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go index 0521e3dc2..97c815cf0 100644 --- a/client/internal/routemanager/notifier/notifier_other.go +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // Not used on non-mobile platforms } +func (n *Notifier) SetFakeIPRoute(*route.Route) { + // Not used on non-mobile platforms +} + func (n *Notifier) OnNewRoutes(idMap route.HAMap) { // Not used on non-mobile platforms } diff --git a/client/internal/routemanager/systemops/systemops_bsd_other.go b/client/internal/routemanager/systemops/systemops_bsd_other.go new file mode 100644 index 000000000..3f09219aa --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_bsd_other.go @@ -0,0 +1,10 @@ +//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin + +package systemops + +// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They +// always fall through to the ref-counter exclusion-route path; these stubs +// exist only so systemops_unix.go compiles. +func (r *SysOps) setupAdvancedRouting() error { return nil } +func (r *SysOps) cleanupAdvancedRouting() error { return nil } +func (r *SysOps) flushPlatformExtras() error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_darwin.go b/client/internal/routemanager/systemops/systemops_darwin.go new file mode 100644 index 000000000..3fcac4c6a --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_darwin.go @@ -0,0 +1,253 @@ +//go:build darwin && !ios + +package systemops + +import ( + "errors" + "fmt" + "net/netip" + "os" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" +) + +// scopedRouteBudget bounds retries for the scoped default route. Installing or +// deleting it matters enough that we're willing to spend longer waiting for the +// kernel reply than for per-prefix exclusion routes. +const scopedRouteBudget = 5 * time.Second + +// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family +// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can +// resolve gateway'd destinations while the VPN's split default owns the +// unscoped table. +// +// Timing note: this runs during routeManager.Init, which happens before the +// VPN interface is created and before any peer routes propagate. The initial +// mgmt / signal / relay TCP dials always fire before this runs, so those +// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route +// lookup, which at that point correctly picks the physical default. Those +// already-established TCP flows keep their originally-selected interface for +// their lifetime on Darwin because the kernel caches the egress route +// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default +// afterwards does not migrate them since the original en0 default stays in +// the table. Any subsequent reconnect via nbnet.NewDialer picks up the +// populated bound-iface cache and gets IP_BOUND_IF set cleanly. +func (r *SysOps) setupAdvancedRouting() error { + // Drop any previously-cached egress interface before reinstalling. On a + // refresh, a family that no longer resolves would otherwise keep the stale + // binding, causing new sockets to scope to an interface without a matching + // scoped default. + nbnet.ClearBoundInterfaces() + + if err := r.flushScopedDefaults(); err != nil { + log.Warnf("flush residual scoped defaults: %v", err) + } + + var merr *multierror.Error + installed := 0 + + for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} { + ok, err := r.installScopedDefaultFor(unspec) + if err != nil { + merr = multierror.Append(merr, err) + continue + } + if ok { + installed++ + } + } + + if installed == 0 && merr != nil { + return nberrors.FormatErrorOrNil(merr) + } + if merr != nil { + log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr)) + } + return nil +} + +// installScopedDefaultFor resolves the physical default nexthop for the given +// address family, installs a scoped default via it, and caches the iface for +// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds. +func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) { + nexthop, err := GetNextHop(unspec) + if err != nil { + if errors.Is(err, vars.ErrRouteNotFound) { + return false, nil + } + return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err) + } + if nexthop.Intf == nil { + return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec) + } + + reused := false + if err := r.addScopedDefault(unspec, nexthop); err != nil { + if !errors.Is(err, unix.EEXIST) { + return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err) + } + // macOS installs its own RTF_IFSCOPE defaults for primary service + // selection on multi-NIC setups, so a route on this ifindex can + // already exist before we try. Binding to it via IP[V6]_BOUND_IF + // still produces the scoped lookup we need. + reused = true + } + + af := unix.AF_INET + if unspec.Is6() { + af = unix.AF_INET6 + } + nbnet.SetBoundInterface(af, nexthop.Intf) + via := "point-to-point" + if nexthop.IP.IsValid() { + via = nexthop.IP.String() + } + verb := "installed" + if reused { + verb = "reused existing" + } + log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec)) + return true, nil +} + +func (r *SysOps) cleanupAdvancedRouting() error { + nbnet.ClearBoundInterfaces() + return r.flushScopedDefaults() +} + +// flushPlatformExtras runs darwin-specific residual cleanup hooked into the +// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get +// removed on the next boot regardless of whether a profile is brought up. +func (r *SysOps) flushPlatformExtras() error { + return r.flushScopedDefaults() +} + +// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag. +// Safe to call at startup to clear residual entries from a prior session. +func (r *SysOps) flushScopedDefaults() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + removed := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + if rtMsg.Flags&unix.RTF_IFSCOPE == 0 { + continue + } + + info, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("skip scoped flush: %v", err) + continue + } + if !info.Dst.IsValid() || info.Dst.Bits() != 0 { + continue + } + + if err := r.deleteScopedRoute(rtMsg); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w", + info.Dst, rtMsg.Index, err)) + continue + } + removed++ + log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index) + } + + if removed > 0 { + log.Infof("flushed %d residual scoped default route(s)", removed) + } + return nberrors.FormatErrorOrNil(merr) +} + +func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error { + return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop) +} + +func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error { + // Preserve identifying flags from the stored route (including RTF_GATEWAY + // only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE. + keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag + del := &route.RouteMessage{ + Type: unix.RTM_DELETE, + Flags: rtMsg.Flags & keep, + Version: unix.RTM_VERSION, + Seq: r.getSeq(), + Index: rtMsg.Index, + Addrs: rtMsg.Addrs, + } + return r.writeRouteMessage(del, scopedRouteBudget) +} + +func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error { + flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag + + msg := &route.RouteMessage{ + Type: action, + Flags: flags, + Version: unix.RTM_VERSION, + ID: uintptr(os.Getpid()), + Seq: r.getSeq(), + Index: nexthop.Intf.Index, + } + + const numAddrs = unix.RTAX_NETMASK + 1 + addrs := make([]route.Addr, numAddrs) + + dst, err := addrToRouteAddr(unspec) + if err != nil { + return fmt.Errorf("build destination: %w", err) + } + mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0)) + if err != nil { + return fmt.Errorf("build netmask: %w", err) + } + addrs[unix.RTAX_DST] = dst + addrs[unix.RTAX_NETMASK] = mask + + if nexthop.IP.IsValid() { + msg.Flags |= unix.RTF_GATEWAY + gw, err := addrToRouteAddr(nexthop.IP.Unmap()) + if err != nil { + return fmt.Errorf("build gateway: %w", err) + } + addrs[unix.RTAX_GATEWAY] = gw + } else { + addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{ + Index: nexthop.Intf.Index, + Name: nexthop.Intf.Name, + } + } + msg.Addrs = addrs + + return r.writeRouteMessage(msg, scopedRouteBudget) +} + +func afOf(a netip.Addr) string { + if a.Is4() { + return "IPv4" + } + return "IPv6" +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index ec219c7fe..4211eb057 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/net/hooks" ) @@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) -var ErrRoutingIsSeparate = errors.New("routing is separate") - func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { stateManager.RegisterState(&ShutdownState{}) @@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { } // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. +// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux, +// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped. func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { - localRoutes, err := hasSeparateRouting() + if nbnet.AdvancedRouting() { + return false, netip.Prefix{} + } + + localRoutes, err := GetRoutesFromTable() if err != nil { - if !errors.Is(err, ErrRoutingIsSeparate) { - log.Errorf("Failed to get routes: %v", err) - } + log.Errorf("Failed to get routes: %v", err) return false, netip.Prefix{} } diff --git a/client/internal/routemanager/systemops/systemops_js.go b/client/internal/routemanager/systemops/systemops_js.go index 808507fc9..242571b3d 100644 --- a/client/internal/routemanager/systemops/systemops_js.go +++ b/client/internal/routemanager/systemops/systemops_js.go @@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { return []netip.Prefix{}, nil } -func hasSeparateRouting() ([]netip.Prefix, error) { - return []netip.Prefix{}, nil -} - // GetDetailedRoutesFromTable returns empty routes for WASM. func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { return []DetailedRoute{}, nil diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index bd10f131f..39a9fd978 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int { return netlink.FAMILY_V6 } -func hasSeparateRouting() ([]netip.Prefix, error) { - if !nbnet.AdvancedRouting() { - return GetRoutesFromTable() - } - return nil, ErrRoutingIsSeparate -} - func isOpErr(err error) bool { // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 905a7bc12..016a62ebd 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -48,10 +48,6 @@ func EnableIPForwarding() error { return nil } -func hasSeparateRouting() ([]netip.Prefix, error) { - return GetRoutesFromTable() -} - // GetIPRules returns IP rules for debugging (not supported on non-Linux platforms) func GetIPRules() ([]IPRule, error) { log.Infof("IP rules collection is not supported on %s", runtime.GOOS) diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 7089178fb..2d3f9b69a 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -25,6 +25,9 @@ import ( const ( envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" + + // routeBudget bounds retries for per-prefix exclusion route programming. + routeBudget = 1 * time.Second ) var routeProtoFlag int @@ -41,26 +44,42 @@ func init() { } func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return r.setupAdvancedRouting() + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return r.cleanupAdvancedRouting() + } + return r.cleanupRefCounter(stateManager) } // FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a +// crashed prior session can't leave crud in the table. func (r *SysOps) FlushMarkedRoutes() error { + var merr *multierror.Error + + if err := r.flushPlatformExtras(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err)) + } + rib, err := retryFetchRIB() if err != nil { - return fmt.Errorf("fetch routing table: %w", err) + return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err))) } msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) if err != nil { - return fmt.Errorf("parse routing table: %w", err) + return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err))) } - var merr *multierror.Error flushedCount := 0 for _, msg := range msgs { @@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e return fmt.Errorf("invalid prefix: %s", prefix) } - expBackOff := backoff.NewExponentialBackOff() - expBackOff.InitialInterval = 50 * time.Millisecond - expBackOff.MaxInterval = 500 * time.Millisecond - expBackOff.MaxElapsedTime = 1 * time.Second + msg, err := r.buildRouteMessage(action, prefix, nexthop) + if err != nil { + return fmt.Errorf("build route message: %w", err) + } - if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil { + if err := r.writeRouteMessage(msg, routeBudget); err != nil { a := "add" if action == unix.RTM_DELETE { a = "remove" @@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e return nil } -func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error { - operation := func() error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) - if err != nil { - return fmt.Errorf("open routing socket: %w", err) +// writeRouteMessage sends a route message over AF_ROUTE and waits for the +// kernel's matching reply, retrying transient failures until budget elapses. +// Callers do not need to manage sockets or seq numbers themselves. +func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error { + expBackOff := backoff.NewExponentialBackOff() + expBackOff.InitialInterval = 50 * time.Millisecond + expBackOff.MaxInterval = 500 * time.Millisecond + expBackOff.MaxElapsedTime = budget + + return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff) +} + +func routeMessageRoundtrip(msg *route.RouteMessage) error { + fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + if err != nil { + return fmt.Errorf("open routing socket: %w", err) + } + defer func() { + if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) { + log.Warnf("close routing socket: %v", err) } - defer func() { - if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) { - log.Warnf("failed to close routing socket: %v", err) + }() + + tv := unix.Timeval{Sec: 1} + if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { + return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err)) + } + + // AF_ROUTE is a broadcast channel: every route socket on the host sees + // every RTM_* event. With concurrent route programming the default + // per-socket queue overflows and our own reply gets dropped. + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil { + log.Debugf("set SO_RCVBUF on route socket: %v", err) + } + + bytes, err := msg.Marshal() + if err != nil { + return backoff.Permanent(fmt.Errorf("marshal: %w", err)) + } + + if _, err = unix.Write(fd, bytes); err != nil { + if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) { + return fmt.Errorf("write: %w", err) + } + return backoff.Permanent(fmt.Errorf("write: %w", err)) + } + return readRouteResponse(fd, msg.Type, msg.Seq) +} + +// readRouteResponse reads from the AF_ROUTE socket until it sees a reply +// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a +// broadcast channel: interface up/down, third-party route changes and neighbor +// discovery events can all land between our write and read, so we must filter. +func readRouteResponse(fd, wantType, wantSeq int) error { + pid := int32(os.Getpid()) + resp := make([]byte, 2048) + deadline := time.Now().Add(time.Second) + for { + if time.Now().After(deadline) { + // Transient: under concurrent pressure the kernel can drop our reply + // from the socket buffer. Let backoff.Retry re-send with a fresh seq. + return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq) + } + n, err := unix.Read(fd, resp) + if err != nil { + if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) { + // SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline. + continue } - }() - - msg, err := r.buildRouteMessage(action, prefix, nexthop) - if err != nil { - return backoff.Permanent(fmt.Errorf("build route message: %w", err)) + return backoff.Permanent(fmt.Errorf("read: %w", err)) } - - msgBytes, err := msg.Marshal() - if err != nil { - return backoff.Permanent(fmt.Errorf("marshal route message: %w", err)) + if n < int(unsafe.Sizeof(unix.RtMsghdr{})) { + continue } - - if _, err = unix.Write(fd, msgBytes); err != nil { - if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) { - return fmt.Errorf("write: %w", err) - } - return backoff.Permanent(fmt.Errorf("write: %w", err)) + hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0])) + // Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid) + // uniquely identifies our own reply among broadcast traffic. + if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid { + continue } - - respBuf := make([]byte, 2048) - n, err := unix.Read(fd, respBuf) - if err != nil { - return backoff.Permanent(fmt.Errorf("read route response: %w", err)) + if hdr.Errno != 0 { + return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno))) } - - if n > 0 { - if err := r.parseRouteResponse(respBuf[:n]); err != nil { - return backoff.Permanent(err) - } - } - return nil } - return operation } func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { @@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next Type: action, Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, + ID: uintptr(os.Getpid()), Seq: r.getSeq(), } @@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next return msg, nil } -func (r *SysOps) parseRouteResponse(buf []byte) error { - if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) { - return nil - } - - rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - if rtMsg.Errno != 0 { - return fmt.Errorf("parse: %d", rtMsg.Errno) - } - - return nil -} - // addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr). func addrToRouteAddr(addr netip.Addr) (route.Addr, error) { if addr.Is4() { diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 61c8bbc79..30afc013b 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/hashicorp/go-multierror" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" @@ -44,8 +43,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) for _, r := range allRoutes { rs.deselectedRoutes[r] = struct{}{} } @@ -78,8 +77,8 @@ func (rs *RouteSelector) SelectAllRoutes() { if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) } // DeselectRoutes removes specific routes from the selection. @@ -116,8 +115,8 @@ func (rs *RouteSelector) DeselectAllRoutes() { if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) } // IsSelected checks if a specific route is selected. diff --git a/client/internal/sleep/detector_darwin.go b/client/internal/sleep/detector_darwin.go index 3d6747ed1..ef495bded 100644 --- a/client/internal/sleep/detector_darwin.go +++ b/client/internal/sleep/detector_darwin.go @@ -2,217 +2,358 @@ package sleep -/* -#cgo LDFLAGS: -framework IOKit -framework CoreFoundation -#include -#include -#include - -extern void sleepCallbackBridge(); -extern void poweredOnCallbackBridge(); -extern void suspendedCallbackBridge(); -extern void resumedCallbackBridge(); - - -// C global variables for IOKit state -static IONotificationPortRef g_notifyPortRef = NULL; -static io_object_t g_notifierObject = 0; -static io_object_t g_generalInterestNotifier = 0; -static io_connect_t g_rootPort = 0; -static CFRunLoopRef g_runLoop = NULL; - -static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) { - switch (messageType) { - case kIOMessageSystemWillSleep: - sleepCallbackBridge(); - IOAllowPowerChange(g_rootPort, (long)messageArgument); - break; - case kIOMessageSystemHasPoweredOn: - poweredOnCallbackBridge(); - break; - case kIOMessageServiceIsSuspended: - suspendedCallbackBridge(); - break; - case kIOMessageServiceIsResumed: - resumedCallbackBridge(); - break; - default: - break; - } -} - -static void registerNotifications() { - g_rootPort = IORegisterForSystemPower( - NULL, - &g_notifyPortRef, - (IOServiceInterestCallback)sleepCallback, - &g_notifierObject - ); - - if (g_rootPort == 0) { - return; - } - - CFRunLoopAddSource(CFRunLoopGetCurrent(), - IONotificationPortGetRunLoopSource(g_notifyPortRef), - kCFRunLoopCommonModes); - - g_runLoop = CFRunLoopGetCurrent(); - CFRunLoopRun(); -} - -static void unregisterNotifications() { - CFRunLoopRemoveSource(g_runLoop, - IONotificationPortGetRunLoopSource(g_notifyPortRef), - kCFRunLoopCommonModes); - - IODeregisterForSystemPower(&g_notifierObject); - IOServiceClose(g_rootPort); - IONotificationPortDestroy(g_notifyPortRef); - CFRunLoopStop(g_runLoop); - - g_notifyPortRef = NULL; - g_notifierObject = 0; - g_rootPort = 0; - g_runLoop = NULL; -} - -*/ -import "C" - import ( - "context" "fmt" "runtime" "sync" "time" + "unsafe" + "github.com/ebitengine/purego" log "github.com/sirupsen/logrus" ) -var ( - serviceRegistry = make(map[*Detector]struct{}) - serviceRegistryMu sync.Mutex +// IOKit message types from IOKit/IOMessage.h. +const ( + kIOMessageCanSystemSleep uintptr = 0xe0000270 + kIOMessageSystemWillSleep uintptr = 0xe0000280 + kIOMessageSystemHasPoweredOn uintptr = 0xe0000300 ) -//export sleepCallbackBridge -func sleepCallbackBridge() { - log.Info("sleepCallbackBridge event triggered") +var ( + ioKit iokitFuncs + cf cfFuncs + cfCommonModes uintptr - serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() + libInitOnce sync.Once + libInitErr error - for svc := range serviceRegistry { - svc.triggerCallback(EventTypeSleep) - } + // callbackThunk is the single C-callable trampoline registered with IOKit. + callbackThunk uintptr + + serviceRegistry = make(map[*Detector]struct{}) + serviceRegistryMu sync.Mutex + session *runLoopSession + + // lifecycleMu serializes Register/Deregister so a new registration can't + // start a second runloop while a previous teardown is still pending. + lifecycleMu sync.Mutex +) + +// iokitFuncs holds IOKit symbols resolved once at init. +type iokitFuncs struct { + IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr + IODeregisterForSystemPower func(notifier *uintptr) int32 + IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32 + IOServiceClose func(connect uintptr) int32 + IONotificationPortGetRunLoopSource func(port uintptr) uintptr + IONotificationPortDestroy func(port uintptr) } -//export resumedCallbackBridge -func resumedCallbackBridge() { - log.Info("resumedCallbackBridge event triggered") +// cfFuncs holds CoreFoundation symbols resolved once at init. +type cfFuncs struct { + CFRunLoopGetCurrent func() uintptr + CFRunLoopRun func() + CFRunLoopStop func(rl uintptr) + CFRunLoopAddSource func(rl, source, mode uintptr) + CFRunLoopRemoveSource func(rl, source, mode uintptr) } -//export suspendedCallbackBridge -func suspendedCallbackBridge() { - log.Info("suspendedCallbackBridge event triggered") +// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil +// session means no runloop is active and the next Register must start one. +type runLoopSession struct { + rl uintptr + port uintptr + notifier uintptr + rp uintptr } -//export poweredOnCallbackBridge -func poweredOnCallbackBridge() { - log.Info("poweredOnCallbackBridge event triggered") - serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() - - for svc := range serviceRegistry { - svc.triggerCallback(EventTypeWakeUp) - } +// detectorSnapshot pins a detector's callback and done channel so dispatch +// runs with values valid at snapshot time, even if a concurrent +// Deregister/Register rewrites the detector's fields. +type detectorSnapshot struct { + detector *Detector + callback func(event EventType) + done <-chan struct{} } +// Detector delivers sleep and wake events to a registered callback. type Detector struct { callback func(event EventType) - ctx context.Context - cancel context.CancelFunc -} - -func NewDetector() (*Detector, error) { - return &Detector{}, nil + done chan struct{} } +// Register installs callback for power events. The first registration starts +// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit +// registration succeeds or fails; subsequent registrations just add to the +// dispatch set. func (d *Detector) Register(callback func(event EventType)) error { - serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() + lifecycleMu.Lock() + defer lifecycleMu.Unlock() + serviceRegistryMu.Lock() if _, exists := serviceRegistry[d]; exists { + serviceRegistryMu.Unlock() return fmt.Errorf("detector service already registered") } - d.callback = callback + d.done = make(chan struct{}) + serviceRegistry[d] = struct{}{} + needSetup := session == nil + serviceRegistryMu.Unlock() - d.ctx, d.cancel = context.WithCancel(context.Background()) - - if len(serviceRegistry) > 0 { - serviceRegistry[d] = struct{}{} + if !needSetup { return nil } - serviceRegistry[d] = struct{}{} - - // CFRunLoop must run on a single fixed OS thread - go func() { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - C.registerNotifications() - }() + errCh := make(chan error, 1) + go runRunLoop(errCh) + if err := <-errCh; err != nil { + serviceRegistryMu.Lock() + delete(serviceRegistry, d) + close(d.done) + d.done = nil + serviceRegistryMu.Unlock() + return err + } log.Info("sleep detection service started on macOS") return nil } -// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down -// and the runloop is stopped and cleaned up. +// Deregister removes the detector. When the last detector leaves, IOKit +// notifications are torn down and the runloop is stopped. func (d *Detector) Deregister() error { + lifecycleMu.Lock() + defer lifecycleMu.Unlock() + serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() - _, exists := serviceRegistry[d] - if !exists { + if _, exists := serviceRegistry[d]; !exists { + serviceRegistryMu.Unlock() return nil } - - // cancel and remove this detector - d.cancel() + close(d.done) delete(serviceRegistry, d) - // If other Detectors still exist, leave IOKit running if len(serviceRegistry) > 0 { + serviceRegistryMu.Unlock() return nil } + sess := session + serviceRegistryMu.Unlock() log.Info("sleep detection service stopping (deregister)") - // Deregister IOKit notifications, stop runloop, and free resources - C.unregisterNotifications() + if sess == nil { + return nil + } + + if sess.rl != 0 && sess.port != 0 { + source := ioKit.IONotificationPortGetRunLoopSource(sess.port) + cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes) + } + if sess.notifier != 0 { + n := sess.notifier + ioKit.IODeregisterForSystemPower(&n) + } + + // Clear session only after IODeregisterForSystemPower returns so any + // in-flight powerCallback can still look up session.rp to ack sleep. + serviceRegistryMu.Lock() + session = nil + serviceRegistryMu.Unlock() + + if sess.rp != 0 { + ioKit.IOServiceClose(sess.rp) + } + if sess.port != 0 { + ioKit.IONotificationPortDestroy(sess.port) + } + if sess.rl != 0 { + cf.CFRunLoopStop(sess.rl) + } return nil } -func (d *Detector) triggerCallback(event EventType) { - doneChan := make(chan struct{}) +func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) { + if cb == nil || done == nil { + return + } + select { + case <-done: + return + default: + } + + doneChan := make(chan struct{}) timeout := time.NewTimer(500 * time.Millisecond) defer timeout.Stop() - cb := d.callback - go func(callback func(event EventType)) { + go func() { + defer close(doneChan) + defer func() { + if r := recover(); r != nil { + log.Errorf("panic in sleep callback: %v", r) + } + }() log.Info("sleep detection event fired") - callback(event) - close(doneChan) - }(cb) + cb(event) + }() select { case <-doneChan: - case <-d.ctx.Done(): + case <-done: case <-timeout.C: - log.Warnf("sleep callback timed out") + log.Warn("sleep callback timed out") } } + +// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector. +func NewDetector() (*Detector, error) { + if err := initLibs(); err != nil { + return nil, err + } + return &Detector{}, nil +} + +func initLibs() error { + libInitOnce.Do(func() { + iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + libInitErr = fmt.Errorf("dlopen IOKit: %w", err) + return + } + cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err) + return + } + + purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower") + purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower") + purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange") + purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose") + purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource") + purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy") + + purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent") + purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun") + purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop") + purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource") + purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource") + + modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes") + if err != nil { + libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err) + return + } + // Launder the uintptr-to-pointer conversion through a Go variable so + // go vet's unsafeptr analyzer doesn't flag a system-library global. + cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr)) + + // NewCallback slots are a finite, non-reclaimable resource, so register + // a single thunk that dispatches to the current Detector set. + callbackThunk = purego.NewCallback(powerCallback) + }) + return libInitErr +} + +// powerCallback is the IOServiceInterestCallback trampoline, invoked on the +// runloop thread. A Go panic crossing the purego boundary has undefined +// behavior, so contain it here. +func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr { + defer func() { + if r := recover(); r != nil { + log.Errorf("panic in sleep powerCallback: %v", r) + } + }() + switch messageType { + case kIOMessageCanSystemSleep: + // Not acknowledging forces a 30s IOKit timeout before idle sleep. + allowPowerChange(messageArgument) + case kIOMessageSystemWillSleep: + dispatchEvent(EventTypeSleep) + allowPowerChange(messageArgument) + case kIOMessageSystemHasPoweredOn: + dispatchEvent(EventTypeWakeUp) + } + return 0 +} + +func allowPowerChange(messageArgument uintptr) { + serviceRegistryMu.Lock() + var port uintptr + if session != nil { + port = session.rp + } + serviceRegistryMu.Unlock() + if port != 0 { + ioKit.IOAllowPowerChange(port, messageArgument) + } +} + +func dispatchEvent(event EventType) { + serviceRegistryMu.Lock() + snaps := make([]detectorSnapshot, 0, len(serviceRegistry)) + for d := range serviceRegistry { + snaps = append(snaps, detectorSnapshot{ + detector: d, + callback: d.callback, + done: d.done, + }) + } + serviceRegistryMu.Unlock() + + for _, s := range snaps { + s.detector.triggerCallback(event, s.callback, s.done) + } +} + +// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup +// result is reported on errCh so Register can surface failures synchronously. +func runRunLoop(errCh chan<- error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + sess, err := setupSession() + if err == nil { + serviceRegistryMu.Lock() + session = sess + serviceRegistryMu.Unlock() + } + errCh <- err + if err != nil { + return + } + + defer func() { + if r := recover(); r != nil { + log.Errorf("panic in sleep runloop: %v", r) + } + }() + cf.CFRunLoopRun() +} + +// setupSession performs the IOKit registration on the current thread. Panics +// are converted to errors so runRunLoop never leaves errCh unsent. +func setupSession() (s *runLoopSession, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during runloop setup: %v", r) + } + }() + + var portRef, notifier uintptr + rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, ¬ifier) + if rp == 0 { + return nil, fmt.Errorf("IORegisterForSystemPower returned zero") + } + + rl := cf.CFRunLoopGetCurrent() + source := ioKit.IONotificationPortGetRunLoopSource(portRef) + cf.CFRunLoopAddSource(rl, source, cfCommonModes) + + return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 3e2da7f4e..043673904 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -161,7 +161,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) + hostDNS := []netip.AddrPort{ + netip.MustParseAddrPort("9.9.9.9:53"), + netip.MustParseAddrPort("149.112.112.112:53"), + } + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile) } // Stop the internal client and free the resources diff --git a/client/net/dialer_init_darwin.go b/client/net/dialer_init_darwin.go new file mode 100644 index 000000000..e18909ff7 --- /dev/null +++ b/client/net/dialer_init_darwin.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyBoundIfToSocket +} diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go index 18ebc6ad1..78973b47d 100644 --- a/client/net/dialer_init_generic.go +++ b/client/net/dialer_init_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows +//go:build !linux && !windows && !darwin package net diff --git a/client/net/env_android.go b/client/net/env_android.go deleted file mode 100644 index 9d89951a1..000000000 --- a/client/net/env_android.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build android - -package net - -// Init initializes the network environment for Android -func Init() { - // No initialization needed on Android -} - -// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. -// Always returns true on Android since we cannot handle routes dynamically. -func AdvancedRouting() bool { - return true -} - -// SetVPNInterfaceName is a no-op on Android -func SetVPNInterfaceName(name string) { - // No-op on Android - not needed for Android VPN service -} - -// GetVPNInterfaceName returns empty string on Android -func GetVPNInterfaceName() string { - return "" -} diff --git a/client/net/env_windows.go b/client/net/env_bound_iface.go similarity index 71% rename from client/net/env_windows.go rename to client/net/env_bound_iface.go index 7e8868ba5..593988c2c 100644 --- a/client/net/env_windows.go +++ b/client/net/env_bound_iface.go @@ -1,4 +1,4 @@ -//go:build windows +//go:build (darwin && !ios) || windows package net @@ -24,17 +24,22 @@ func Init() { } func checkAdvancedRoutingSupport() bool { - var err error - var legacyRouting bool + legacyRouting := false if val := os.Getenv(envUseLegacyRouting); val != "" { - legacyRouting, err = strconv.ParseBool(val) + parsed, err := strconv.ParseBool(val) if err != nil { - log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err) + } else { + legacyRouting = parsed } } - if legacyRouting || netstack.IsEnabled() { - log.Info("advanced routing has been requested to be disabled") + if legacyRouting { + log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting) + return false + } + if netstack.IsEnabled() { + log.Info("advanced routing disabled: netstack mode is enabled") return false } diff --git a/client/net/env_generic.go b/client/net/env_generic.go index f467930c3..18c10bb78 100644 --- a/client/net/env_generic.go +++ b/client/net/env_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows && !android +//go:build !linux && !windows && !darwin package net diff --git a/client/net/env_mobile.go b/client/net/env_mobile.go new file mode 100644 index 000000000..80b0fad8d --- /dev/null +++ b/client/net/env_mobile.go @@ -0,0 +1,25 @@ +//go:build ios || android + +package net + +// Init initializes the network environment for mobile platforms. +func Init() { + // no-op on mobile: routing scope is owned by the VPN extension. +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension +// owns the routing scope. +func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on mobile. +func SetVPNInterfaceName(string) { + // no-op on mobile: the VPN extension manages the interface. +} + +// GetVPNInterfaceName returns an empty string on mobile. +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/listener_init_darwin.go b/client/net/listener_init_darwin.go new file mode 100644 index 000000000..f2fcc80ed --- /dev/null +++ b/client/net/listener_init_darwin.go @@ -0,0 +1,5 @@ +package net + +func (l *ListenerConfig) init() { + l.ListenConfig.Control = applyBoundIfToSocket +} diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go index 4f8f17ab2..65a785222 100644 --- a/client/net/listener_init_generic.go +++ b/client/net/listener_init_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows +//go:build !linux && !windows && !darwin package net diff --git a/client/net/net_darwin.go b/client/net/net_darwin.go new file mode 100644 index 000000000..00d858a6a --- /dev/null +++ b/client/net/net_darwin.go @@ -0,0 +1,160 @@ +package net + +import ( + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "sync" + "syscall" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack +// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6" +// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns +// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is +// dispatched by socket domain (AF_INET only) not by inp_vflag. + +// boundIface holds the physical interface chosen at routing setup time. Sockets +// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF +// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup +// hits the RTF_IFSCOPE default installed by the routemanager, rather than +// following the VPN's split default. +var ( + boundIfaceMu sync.RWMutex + boundIface4 *net.Interface + boundIface6 *net.Interface +) + +// SetBoundInterface records the egress interface for an address family. Called +// by the routemanager after a scoped default route has been installed. +// af must be unix.AF_INET or unix.AF_INET6; other values are ignored. +// nil iface is rejected — use ClearBoundInterfaces to clear all slots. +func SetBoundInterface(af int, iface *net.Interface) { + if iface == nil { + log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af) + return + } + boundIfaceMu.Lock() + defer boundIfaceMu.Unlock() + switch af { + case unix.AF_INET: + boundIface4 = iface + case unix.AF_INET6: + boundIface6 = iface + default: + log.Warnf("SetBoundInterface: unsupported address family %d", af) + } +} + +// ClearBoundInterfaces resets the cached egress interfaces. Called by the +// routemanager during cleanup. +func ClearBoundInterfaces() { + boundIfaceMu.Lock() + defer boundIfaceMu.Unlock() + boundIface4 = nil + boundIface6 = nil +} + +// boundInterfaceFor returns the cached egress interface for a socket's address +// family, falling back to the other family if the preferred slot is empty. +// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so +// either setsockopt scopes the socket; preferring same-family still matters +// when v4 and v6 defaults egress different NICs. +func boundInterfaceFor(network, address string) *net.Interface { + if iface := zoneInterface(address); iface != nil { + return iface + } + + boundIfaceMu.RLock() + defer boundIfaceMu.RUnlock() + + primary, secondary := boundIface4, boundIface6 + if isV6Network(network) { + primary, secondary = boundIface6, boundIface4 + } + if primary != nil { + return primary + } + return secondary +} + +func isV6Network(network string) bool { + return strings.HasSuffix(network, "6") +} + +// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0). +func zoneInterface(address string) *net.Interface { + if address == "" { + return nil + } + addr, err := netip.ParseAddrPort(address) + if err != nil { + a, err := netip.ParseAddr(address) + if err != nil { + return nil + } + addr = netip.AddrPortFrom(a, 0) + } + zone := addr.Addr().Zone() + if zone == "" { + return nil + } + if iface, err := net.InterfaceByName(zone); err == nil { + return iface + } + if idx, err := strconv.Atoi(zone); err == nil { + if iface, err := net.InterfaceByIndex(idx); err == nil { + return iface + } + } + return nil +} + +func setIPv4BoundIf(fd uintptr, iface *net.Interface) error { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil { + return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6BoundIf(fd uintptr, iface *net.Interface) error { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil { + return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +// applyBoundIfToSocket binds the socket to the cached physical egress interface +// so scoped route lookup avoids the VPN utun and egresses the underlay directly. +func applyBoundIfToSocket(network, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + iface := boundInterfaceFor(network, address) + if iface == nil { + log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address) + return nil + } + + isV6 := isV6Network(network) + var controlErr error + if err := c.Control(func(fd uintptr) { + if isV6 { + controlErr = setIPv6BoundIf(fd, iface) + } else { + controlErr = setIPv4BoundIf(fd, iface) + } + if controlErr == nil { + log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address) + } + }); err != nil { + return fmt.Errorf("control: %w", err) + } + return controlErr +} diff --git a/client/netbird.wxs b/client/netbird.wxs index 03221dd91..6f18b63b5 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -13,15 +13,25 @@ + + + - - + + + + + + + + + @@ -46,8 +56,30 @@ + + + + + + + + + + + + + + + + + + + diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index fa0b2f93b..11e7877f2 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -143,56 +143,6 @@ func (ExposeProtocol) EnumDescriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{1} } -// avoid collision with loglevel enum -type OSLifecycleRequest_CycleType int32 - -const ( - OSLifecycleRequest_UNKNOWN OSLifecycleRequest_CycleType = 0 - OSLifecycleRequest_SLEEP OSLifecycleRequest_CycleType = 1 - OSLifecycleRequest_WAKEUP OSLifecycleRequest_CycleType = 2 -) - -// Enum value maps for OSLifecycleRequest_CycleType. -var ( - OSLifecycleRequest_CycleType_name = map[int32]string{ - 0: "UNKNOWN", - 1: "SLEEP", - 2: "WAKEUP", - } - OSLifecycleRequest_CycleType_value = map[string]int32{ - "UNKNOWN": 0, - "SLEEP": 1, - "WAKEUP": 2, - } -) - -func (x OSLifecycleRequest_CycleType) Enum() *OSLifecycleRequest_CycleType { - p := new(OSLifecycleRequest_CycleType) - *p = x - return p -} - -func (x OSLifecycleRequest_CycleType) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[2].Descriptor() -} - -func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[2] -} - -func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use OSLifecycleRequest_CycleType.Descriptor instead. -func (OSLifecycleRequest_CycleType) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{1, 0} -} - type SystemEvent_Severity int32 const ( @@ -229,11 +179,11 @@ func (x SystemEvent_Severity) String() string { } func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[3].Descriptor() + return file_daemon_proto_enumTypes[2].Descriptor() } func (SystemEvent_Severity) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[3] + return &file_daemon_proto_enumTypes[2] } func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { @@ -242,7 +192,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Severity.Descriptor instead. func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53, 0} + return file_daemon_proto_rawDescGZIP(), []int{51, 0} } type SystemEvent_Category int32 @@ -284,11 +234,11 @@ func (x SystemEvent_Category) String() string { } func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[4].Descriptor() + return file_daemon_proto_enumTypes[3].Descriptor() } func (SystemEvent_Category) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[4] + return &file_daemon_proto_enumTypes[3] } func (x SystemEvent_Category) Number() protoreflect.EnumNumber { @@ -297,7 +247,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Category.Descriptor instead. func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53, 1} + return file_daemon_proto_rawDescGZIP(), []int{51, 1} } type EmptyRequest struct { @@ -336,86 +286,6 @@ func (*EmptyRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } -type OSLifecycleRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Type OSLifecycleRequest_CycleType `protobuf:"varint,1,opt,name=type,proto3,enum=daemon.OSLifecycleRequest_CycleType" json:"type,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *OSLifecycleRequest) Reset() { - *x = OSLifecycleRequest{} - mi := &file_daemon_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *OSLifecycleRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*OSLifecycleRequest) ProtoMessage() {} - -func (x *OSLifecycleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[1] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use OSLifecycleRequest.ProtoReflect.Descriptor instead. -func (*OSLifecycleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{1} -} - -func (x *OSLifecycleRequest) GetType() OSLifecycleRequest_CycleType { - if x != nil { - return x.Type - } - return OSLifecycleRequest_UNKNOWN -} - -type OSLifecycleResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *OSLifecycleResponse) Reset() { - *x = OSLifecycleResponse{} - mi := &file_daemon_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *OSLifecycleResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*OSLifecycleResponse) ProtoMessage() {} - -func (x *OSLifecycleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use OSLifecycleResponse.ProtoReflect.Descriptor instead. -func (*OSLifecycleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{2} -} - type LoginRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // setupKey netbird setup key. @@ -478,7 +348,7 @@ type LoginRequest struct { func (x *LoginRequest) Reset() { *x = LoginRequest{} - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -490,7 +360,7 @@ func (x *LoginRequest) String() string { func (*LoginRequest) ProtoMessage() {} func (x *LoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -503,7 +373,7 @@ func (x *LoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead. func (*LoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{3} + return file_daemon_proto_rawDescGZIP(), []int{1} } func (x *LoginRequest) GetSetupKey() string { @@ -792,7 +662,7 @@ type LoginResponse struct { func (x *LoginResponse) Reset() { *x = LoginResponse{} - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -804,7 +674,7 @@ func (x *LoginResponse) String() string { func (*LoginResponse) ProtoMessage() {} func (x *LoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -817,7 +687,7 @@ func (x *LoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead. func (*LoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{4} + return file_daemon_proto_rawDescGZIP(), []int{2} } func (x *LoginResponse) GetNeedsSSOLogin() bool { @@ -858,7 +728,7 @@ type WaitSSOLoginRequest struct { func (x *WaitSSOLoginRequest) Reset() { *x = WaitSSOLoginRequest{} - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -870,7 +740,7 @@ func (x *WaitSSOLoginRequest) String() string { func (*WaitSSOLoginRequest) ProtoMessage() {} func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -883,7 +753,7 @@ func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginRequest.ProtoReflect.Descriptor instead. func (*WaitSSOLoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{5} + return file_daemon_proto_rawDescGZIP(), []int{3} } func (x *WaitSSOLoginRequest) GetUserCode() string { @@ -909,7 +779,7 @@ type WaitSSOLoginResponse struct { func (x *WaitSSOLoginResponse) Reset() { *x = WaitSSOLoginResponse{} - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -921,7 +791,7 @@ func (x *WaitSSOLoginResponse) String() string { func (*WaitSSOLoginResponse) ProtoMessage() {} func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -934,7 +804,7 @@ func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginResponse.ProtoReflect.Descriptor instead. func (*WaitSSOLoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{6} + return file_daemon_proto_rawDescGZIP(), []int{4} } func (x *WaitSSOLoginResponse) GetEmail() string { @@ -954,7 +824,7 @@ type UpRequest struct { func (x *UpRequest) Reset() { *x = UpRequest{} - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -966,7 +836,7 @@ func (x *UpRequest) String() string { func (*UpRequest) ProtoMessage() {} func (x *UpRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -979,7 +849,7 @@ func (x *UpRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use UpRequest.ProtoReflect.Descriptor instead. func (*UpRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{7} + return file_daemon_proto_rawDescGZIP(), []int{5} } func (x *UpRequest) GetProfileName() string { @@ -1004,7 +874,7 @@ type UpResponse struct { func (x *UpResponse) Reset() { *x = UpResponse{} - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1016,7 +886,7 @@ func (x *UpResponse) String() string { func (*UpResponse) ProtoMessage() {} func (x *UpResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1029,7 +899,7 @@ func (x *UpResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use UpResponse.ProtoReflect.Descriptor instead. func (*UpResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{8} + return file_daemon_proto_rawDescGZIP(), []int{6} } type StatusRequest struct { @@ -1044,7 +914,7 @@ type StatusRequest struct { func (x *StatusRequest) Reset() { *x = StatusRequest{} - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1056,7 +926,7 @@ func (x *StatusRequest) String() string { func (*StatusRequest) ProtoMessage() {} func (x *StatusRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1069,7 +939,7 @@ func (x *StatusRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead. func (*StatusRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{9} + return file_daemon_proto_rawDescGZIP(), []int{7} } func (x *StatusRequest) GetGetFullPeerStatus() bool { @@ -1106,7 +976,7 @@ type StatusResponse struct { func (x *StatusResponse) Reset() { *x = StatusResponse{} - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1118,7 +988,7 @@ func (x *StatusResponse) String() string { func (*StatusResponse) ProtoMessage() {} func (x *StatusResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1131,7 +1001,7 @@ func (x *StatusResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead. func (*StatusResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{10} + return file_daemon_proto_rawDescGZIP(), []int{8} } func (x *StatusResponse) GetStatus() string { @@ -1163,7 +1033,7 @@ type DownRequest struct { func (x *DownRequest) Reset() { *x = DownRequest{} - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1175,7 +1045,7 @@ func (x *DownRequest) String() string { func (*DownRequest) ProtoMessage() {} func (x *DownRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1188,7 +1058,7 @@ func (x *DownRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DownRequest.ProtoReflect.Descriptor instead. func (*DownRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{11} + return file_daemon_proto_rawDescGZIP(), []int{9} } type DownResponse struct { @@ -1199,7 +1069,7 @@ type DownResponse struct { func (x *DownResponse) Reset() { *x = DownResponse{} - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1211,7 +1081,7 @@ func (x *DownResponse) String() string { func (*DownResponse) ProtoMessage() {} func (x *DownResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1224,7 +1094,7 @@ func (x *DownResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DownResponse.ProtoReflect.Descriptor instead. func (*DownResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{12} + return file_daemon_proto_rawDescGZIP(), []int{10} } type GetConfigRequest struct { @@ -1237,7 +1107,7 @@ type GetConfigRequest struct { func (x *GetConfigRequest) Reset() { *x = GetConfigRequest{} - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1249,7 +1119,7 @@ func (x *GetConfigRequest) String() string { func (*GetConfigRequest) ProtoMessage() {} func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1262,7 +1132,7 @@ func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigRequest.ProtoReflect.Descriptor instead. func (*GetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{13} + return file_daemon_proto_rawDescGZIP(), []int{11} } func (x *GetConfigRequest) GetProfileName() string { @@ -1318,7 +1188,7 @@ type GetConfigResponse struct { func (x *GetConfigResponse) Reset() { *x = GetConfigResponse{} - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1330,7 +1200,7 @@ func (x *GetConfigResponse) String() string { func (*GetConfigResponse) ProtoMessage() {} func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1343,7 +1213,7 @@ func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigResponse.ProtoReflect.Descriptor instead. func (*GetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{14} + return file_daemon_proto_rawDescGZIP(), []int{12} } func (x *GetConfigResponse) GetManagementUrl() string { @@ -1555,7 +1425,7 @@ type PeerState struct { func (x *PeerState) Reset() { *x = PeerState{} - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1567,7 +1437,7 @@ func (x *PeerState) String() string { func (*PeerState) ProtoMessage() {} func (x *PeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1580,7 +1450,7 @@ func (x *PeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerState.ProtoReflect.Descriptor instead. func (*PeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{15} + return file_daemon_proto_rawDescGZIP(), []int{13} } func (x *PeerState) GetIP() string { @@ -1725,7 +1595,7 @@ type LocalPeerState struct { func (x *LocalPeerState) Reset() { *x = LocalPeerState{} - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1737,7 +1607,7 @@ func (x *LocalPeerState) String() string { func (*LocalPeerState) ProtoMessage() {} func (x *LocalPeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1750,7 +1620,7 @@ func (x *LocalPeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use LocalPeerState.ProtoReflect.Descriptor instead. func (*LocalPeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{16} + return file_daemon_proto_rawDescGZIP(), []int{14} } func (x *LocalPeerState) GetIP() string { @@ -1814,7 +1684,7 @@ type SignalState struct { func (x *SignalState) Reset() { *x = SignalState{} - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1826,7 +1696,7 @@ func (x *SignalState) String() string { func (*SignalState) ProtoMessage() {} func (x *SignalState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1839,7 +1709,7 @@ func (x *SignalState) ProtoReflect() protoreflect.Message { // Deprecated: Use SignalState.ProtoReflect.Descriptor instead. func (*SignalState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{17} + return file_daemon_proto_rawDescGZIP(), []int{15} } func (x *SignalState) GetURL() string { @@ -1875,7 +1745,7 @@ type ManagementState struct { func (x *ManagementState) Reset() { *x = ManagementState{} - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1887,7 +1757,7 @@ func (x *ManagementState) String() string { func (*ManagementState) ProtoMessage() {} func (x *ManagementState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1900,7 +1770,7 @@ func (x *ManagementState) ProtoReflect() protoreflect.Message { // Deprecated: Use ManagementState.ProtoReflect.Descriptor instead. func (*ManagementState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{18} + return file_daemon_proto_rawDescGZIP(), []int{16} } func (x *ManagementState) GetURL() string { @@ -1936,7 +1806,7 @@ type RelayState struct { func (x *RelayState) Reset() { *x = RelayState{} - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1948,7 +1818,7 @@ func (x *RelayState) String() string { func (*RelayState) ProtoMessage() {} func (x *RelayState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1961,7 +1831,7 @@ func (x *RelayState) ProtoReflect() protoreflect.Message { // Deprecated: Use RelayState.ProtoReflect.Descriptor instead. func (*RelayState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{19} + return file_daemon_proto_rawDescGZIP(), []int{17} } func (x *RelayState) GetURI() string { @@ -1997,7 +1867,7 @@ type NSGroupState struct { func (x *NSGroupState) Reset() { *x = NSGroupState{} - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2009,7 +1879,7 @@ func (x *NSGroupState) String() string { func (*NSGroupState) ProtoMessage() {} func (x *NSGroupState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2022,7 +1892,7 @@ func (x *NSGroupState) ProtoReflect() protoreflect.Message { // Deprecated: Use NSGroupState.ProtoReflect.Descriptor instead. func (*NSGroupState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{20} + return file_daemon_proto_rawDescGZIP(), []int{18} } func (x *NSGroupState) GetServers() []string { @@ -2067,7 +1937,7 @@ type SSHSessionInfo struct { func (x *SSHSessionInfo) Reset() { *x = SSHSessionInfo{} - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2079,7 +1949,7 @@ func (x *SSHSessionInfo) String() string { func (*SSHSessionInfo) ProtoMessage() {} func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2092,7 +1962,7 @@ func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead. func (*SSHSessionInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{21} + return file_daemon_proto_rawDescGZIP(), []int{19} } func (x *SSHSessionInfo) GetUsername() string { @@ -2141,7 +2011,7 @@ type SSHServerState struct { func (x *SSHServerState) Reset() { *x = SSHServerState{} - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2153,7 +2023,7 @@ func (x *SSHServerState) String() string { func (*SSHServerState) ProtoMessage() {} func (x *SSHServerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2166,7 +2036,7 @@ func (x *SSHServerState) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead. func (*SSHServerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{22} + return file_daemon_proto_rawDescGZIP(), []int{20} } func (x *SSHServerState) GetEnabled() bool { @@ -2202,7 +2072,7 @@ type FullStatus struct { func (x *FullStatus) Reset() { *x = FullStatus{} - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2214,7 +2084,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2227,7 +2097,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{21} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -2309,7 +2179,7 @@ type ListNetworksRequest struct { func (x *ListNetworksRequest) Reset() { *x = ListNetworksRequest{} - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2321,7 +2191,7 @@ func (x *ListNetworksRequest) String() string { func (*ListNetworksRequest) ProtoMessage() {} func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2334,7 +2204,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. func (*ListNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{22} } type ListNetworksResponse struct { @@ -2346,7 +2216,7 @@ type ListNetworksResponse struct { func (x *ListNetworksResponse) Reset() { *x = ListNetworksResponse{} - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2358,7 +2228,7 @@ func (x *ListNetworksResponse) String() string { func (*ListNetworksResponse) ProtoMessage() {} func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2371,7 +2241,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. func (*ListNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{23} } func (x *ListNetworksResponse) GetRoutes() []*Network { @@ -2392,7 +2262,7 @@ type SelectNetworksRequest struct { func (x *SelectNetworksRequest) Reset() { *x = SelectNetworksRequest{} - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2404,7 +2274,7 @@ func (x *SelectNetworksRequest) String() string { func (*SelectNetworksRequest) ProtoMessage() {} func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2417,7 +2287,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{24} } func (x *SelectNetworksRequest) GetNetworkIDs() []string { @@ -2449,7 +2319,7 @@ type SelectNetworksResponse struct { func (x *SelectNetworksResponse) Reset() { *x = SelectNetworksResponse{} - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2461,7 +2331,7 @@ func (x *SelectNetworksResponse) String() string { func (*SelectNetworksResponse) ProtoMessage() {} func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2474,7 +2344,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{25} } type IPList struct { @@ -2486,7 +2356,7 @@ type IPList struct { func (x *IPList) Reset() { *x = IPList{} - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2498,7 +2368,7 @@ func (x *IPList) String() string { func (*IPList) ProtoMessage() {} func (x *IPList) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2511,7 +2381,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message { // Deprecated: Use IPList.ProtoReflect.Descriptor instead. func (*IPList) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *IPList) GetIps() []string { @@ -2534,7 +2404,7 @@ type Network struct { func (x *Network) Reset() { *x = Network{} - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2546,7 +2416,7 @@ func (x *Network) String() string { func (*Network) ProtoMessage() {} func (x *Network) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2559,7 +2429,7 @@ func (x *Network) ProtoReflect() protoreflect.Message { // Deprecated: Use Network.ProtoReflect.Descriptor instead. func (*Network) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{27} } func (x *Network) GetID() string { @@ -2611,7 +2481,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2623,7 +2493,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2636,7 +2506,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -2693,7 +2563,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2705,7 +2575,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2718,7 +2588,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *ForwardingRule) GetProtocol() string { @@ -2765,7 +2635,7 @@ type ForwardingRulesResponse struct { func (x *ForwardingRulesResponse) Reset() { *x = ForwardingRulesResponse{} - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2777,7 +2647,7 @@ func (x *ForwardingRulesResponse) String() string { func (*ForwardingRulesResponse) ProtoMessage() {} func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2790,7 +2660,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRulesResponse.ProtoReflect.Descriptor instead. func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{30} } func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule { @@ -2813,7 +2683,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2825,7 +2695,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2838,7 +2708,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{33} + return file_daemon_proto_rawDescGZIP(), []int{31} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -2880,7 +2750,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2892,7 +2762,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2905,7 +2775,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{34} + return file_daemon_proto_rawDescGZIP(), []int{32} } func (x *DebugBundleResponse) GetPath() string { @@ -2937,7 +2807,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2949,7 +2819,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2962,7 +2832,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{35} + return file_daemon_proto_rawDescGZIP(), []int{33} } type GetLogLevelResponse struct { @@ -2974,7 +2844,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2986,7 +2856,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2999,7 +2869,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{36} + return file_daemon_proto_rawDescGZIP(), []int{34} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -3018,7 +2888,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3030,7 +2900,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3043,7 +2913,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{37} + return file_daemon_proto_rawDescGZIP(), []int{35} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -3061,7 +2931,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3073,7 +2943,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3086,7 +2956,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{38} + return file_daemon_proto_rawDescGZIP(), []int{36} } // State represents a daemon state entry @@ -3099,7 +2969,7 @@ type State struct { func (x *State) Reset() { *x = State{} - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3111,7 +2981,7 @@ func (x *State) String() string { func (*State) ProtoMessage() {} func (x *State) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3124,7 +2994,7 @@ func (x *State) ProtoReflect() protoreflect.Message { // Deprecated: Use State.ProtoReflect.Descriptor instead. func (*State) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{39} + return file_daemon_proto_rawDescGZIP(), []int{37} } func (x *State) GetName() string { @@ -3143,7 +3013,7 @@ type ListStatesRequest struct { func (x *ListStatesRequest) Reset() { *x = ListStatesRequest{} - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3155,7 +3025,7 @@ func (x *ListStatesRequest) String() string { func (*ListStatesRequest) ProtoMessage() {} func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3168,7 +3038,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. func (*ListStatesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{40} + return file_daemon_proto_rawDescGZIP(), []int{38} } // ListStatesResponse contains a list of states @@ -3181,7 +3051,7 @@ type ListStatesResponse struct { func (x *ListStatesResponse) Reset() { *x = ListStatesResponse{} - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3193,7 +3063,7 @@ func (x *ListStatesResponse) String() string { func (*ListStatesResponse) ProtoMessage() {} func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3206,7 +3076,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. func (*ListStatesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{41} + return file_daemon_proto_rawDescGZIP(), []int{39} } func (x *ListStatesResponse) GetStates() []*State { @@ -3227,7 +3097,7 @@ type CleanStateRequest struct { func (x *CleanStateRequest) Reset() { *x = CleanStateRequest{} - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3239,7 +3109,7 @@ func (x *CleanStateRequest) String() string { func (*CleanStateRequest) ProtoMessage() {} func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3252,7 +3122,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. func (*CleanStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{42} + return file_daemon_proto_rawDescGZIP(), []int{40} } func (x *CleanStateRequest) GetStateName() string { @@ -3279,7 +3149,7 @@ type CleanStateResponse struct { func (x *CleanStateResponse) Reset() { *x = CleanStateResponse{} - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3291,7 +3161,7 @@ func (x *CleanStateResponse) String() string { func (*CleanStateResponse) ProtoMessage() {} func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3304,7 +3174,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. func (*CleanStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{43} + return file_daemon_proto_rawDescGZIP(), []int{41} } func (x *CleanStateResponse) GetCleanedStates() int32 { @@ -3325,7 +3195,7 @@ type DeleteStateRequest struct { func (x *DeleteStateRequest) Reset() { *x = DeleteStateRequest{} - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3337,7 +3207,7 @@ func (x *DeleteStateRequest) String() string { func (*DeleteStateRequest) ProtoMessage() {} func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3350,7 +3220,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. func (*DeleteStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{44} + return file_daemon_proto_rawDescGZIP(), []int{42} } func (x *DeleteStateRequest) GetStateName() string { @@ -3377,7 +3247,7 @@ type DeleteStateResponse struct { func (x *DeleteStateResponse) Reset() { *x = DeleteStateResponse{} - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3389,7 +3259,7 @@ func (x *DeleteStateResponse) String() string { func (*DeleteStateResponse) ProtoMessage() {} func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3402,7 +3272,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. func (*DeleteStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{45} + return file_daemon_proto_rawDescGZIP(), []int{43} } func (x *DeleteStateResponse) GetDeletedStates() int32 { @@ -3421,7 +3291,7 @@ type SetSyncResponsePersistenceRequest struct { func (x *SetSyncResponsePersistenceRequest) Reset() { *x = SetSyncResponsePersistenceRequest{} - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3433,7 +3303,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string { func (*SetSyncResponsePersistenceRequest) ProtoMessage() {} func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3446,7 +3316,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{46} + return file_daemon_proto_rawDescGZIP(), []int{44} } func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool { @@ -3464,7 +3334,7 @@ type SetSyncResponsePersistenceResponse struct { func (x *SetSyncResponsePersistenceResponse) Reset() { *x = SetSyncResponsePersistenceResponse{} - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3476,7 +3346,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string { func (*SetSyncResponsePersistenceResponse) ProtoMessage() {} func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[45] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3489,7 +3359,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{47} + return file_daemon_proto_rawDescGZIP(), []int{45} } type TCPFlags struct { @@ -3506,7 +3376,7 @@ type TCPFlags struct { func (x *TCPFlags) Reset() { *x = TCPFlags{} - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3518,7 +3388,7 @@ func (x *TCPFlags) String() string { func (*TCPFlags) ProtoMessage() {} func (x *TCPFlags) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[46] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3531,7 +3401,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message { // Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. func (*TCPFlags) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{48} + return file_daemon_proto_rawDescGZIP(), []int{46} } func (x *TCPFlags) GetSyn() bool { @@ -3593,7 +3463,7 @@ type TracePacketRequest struct { func (x *TracePacketRequest) Reset() { *x = TracePacketRequest{} - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3605,7 +3475,7 @@ func (x *TracePacketRequest) String() string { func (*TracePacketRequest) ProtoMessage() {} func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[47] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3618,7 +3488,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. func (*TracePacketRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49} + return file_daemon_proto_rawDescGZIP(), []int{47} } func (x *TracePacketRequest) GetSourceIp() string { @@ -3696,7 +3566,7 @@ type TraceStage struct { func (x *TraceStage) Reset() { *x = TraceStage{} - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3708,7 +3578,7 @@ func (x *TraceStage) String() string { func (*TraceStage) ProtoMessage() {} func (x *TraceStage) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[48] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3721,7 +3591,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message { // Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. func (*TraceStage) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{50} + return file_daemon_proto_rawDescGZIP(), []int{48} } func (x *TraceStage) GetName() string { @@ -3762,7 +3632,7 @@ type TracePacketResponse struct { func (x *TracePacketResponse) Reset() { *x = TracePacketResponse{} - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3774,7 +3644,7 @@ func (x *TracePacketResponse) String() string { func (*TracePacketResponse) ProtoMessage() {} func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[49] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3787,7 +3657,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. func (*TracePacketResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51} + return file_daemon_proto_rawDescGZIP(), []int{49} } func (x *TracePacketResponse) GetStages() []*TraceStage { @@ -3812,7 +3682,7 @@ type SubscribeRequest struct { func (x *SubscribeRequest) Reset() { *x = SubscribeRequest{} - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3824,7 +3694,7 @@ func (x *SubscribeRequest) String() string { func (*SubscribeRequest) ProtoMessage() {} func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[50] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3837,7 +3707,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. func (*SubscribeRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{52} + return file_daemon_proto_rawDescGZIP(), []int{50} } type SystemEvent struct { @@ -3855,7 +3725,7 @@ type SystemEvent struct { func (x *SystemEvent) Reset() { *x = SystemEvent{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3867,7 +3737,7 @@ func (x *SystemEvent) String() string { func (*SystemEvent) ProtoMessage() {} func (x *SystemEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[51] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3880,7 +3750,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead. func (*SystemEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53} + return file_daemon_proto_rawDescGZIP(), []int{51} } func (x *SystemEvent) GetId() string { @@ -3940,7 +3810,7 @@ type GetEventsRequest struct { func (x *GetEventsRequest) Reset() { *x = GetEventsRequest{} - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3952,7 +3822,7 @@ func (x *GetEventsRequest) String() string { func (*GetEventsRequest) ProtoMessage() {} func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[52] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3965,7 +3835,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead. func (*GetEventsRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{54} + return file_daemon_proto_rawDescGZIP(), []int{52} } type GetEventsResponse struct { @@ -3977,7 +3847,7 @@ type GetEventsResponse struct { func (x *GetEventsResponse) Reset() { *x = GetEventsResponse{} - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3989,7 +3859,7 @@ func (x *GetEventsResponse) String() string { func (*GetEventsResponse) ProtoMessage() {} func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[53] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4002,7 +3872,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead. func (*GetEventsResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{55} + return file_daemon_proto_rawDescGZIP(), []int{53} } func (x *GetEventsResponse) GetEvents() []*SystemEvent { @@ -4022,7 +3892,7 @@ type SwitchProfileRequest struct { func (x *SwitchProfileRequest) Reset() { *x = SwitchProfileRequest{} - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4034,7 +3904,7 @@ func (x *SwitchProfileRequest) String() string { func (*SwitchProfileRequest) ProtoMessage() {} func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[54] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4047,7 +3917,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{56} + return file_daemon_proto_rawDescGZIP(), []int{54} } func (x *SwitchProfileRequest) GetProfileName() string { @@ -4072,7 +3942,7 @@ type SwitchProfileResponse struct { func (x *SwitchProfileResponse) Reset() { *x = SwitchProfileResponse{} - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4084,7 +3954,7 @@ func (x *SwitchProfileResponse) String() string { func (*SwitchProfileResponse) ProtoMessage() {} func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[55] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4097,7 +3967,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{57} + return file_daemon_proto_rawDescGZIP(), []int{55} } type SetConfigRequest struct { @@ -4145,7 +4015,7 @@ type SetConfigRequest struct { func (x *SetConfigRequest) Reset() { *x = SetConfigRequest{} - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4157,7 +4027,7 @@ func (x *SetConfigRequest) String() string { func (*SetConfigRequest) ProtoMessage() {} func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[56] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4170,7 +4040,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. func (*SetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{58} + return file_daemon_proto_rawDescGZIP(), []int{56} } func (x *SetConfigRequest) GetUsername() string { @@ -4419,7 +4289,7 @@ type SetConfigResponse struct { func (x *SetConfigResponse) Reset() { *x = SetConfigResponse{} - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4431,7 +4301,7 @@ func (x *SetConfigResponse) String() string { func (*SetConfigResponse) ProtoMessage() {} func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[57] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4444,7 +4314,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. func (*SetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{59} + return file_daemon_proto_rawDescGZIP(), []int{57} } type AddProfileRequest struct { @@ -4457,7 +4327,7 @@ type AddProfileRequest struct { func (x *AddProfileRequest) Reset() { *x = AddProfileRequest{} - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4469,7 +4339,7 @@ func (x *AddProfileRequest) String() string { func (*AddProfileRequest) ProtoMessage() {} func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[58] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4482,7 +4352,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. func (*AddProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{60} + return file_daemon_proto_rawDescGZIP(), []int{58} } func (x *AddProfileRequest) GetUsername() string { @@ -4507,7 +4377,7 @@ type AddProfileResponse struct { func (x *AddProfileResponse) Reset() { *x = AddProfileResponse{} - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4519,7 +4389,7 @@ func (x *AddProfileResponse) String() string { func (*AddProfileResponse) ProtoMessage() {} func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[59] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4532,7 +4402,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead. func (*AddProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{61} + return file_daemon_proto_rawDescGZIP(), []int{59} } type RemoveProfileRequest struct { @@ -4545,7 +4415,7 @@ type RemoveProfileRequest struct { func (x *RemoveProfileRequest) Reset() { *x = RemoveProfileRequest{} - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4557,7 +4427,7 @@ func (x *RemoveProfileRequest) String() string { func (*RemoveProfileRequest) ProtoMessage() {} func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[60] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4570,7 +4440,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{62} + return file_daemon_proto_rawDescGZIP(), []int{60} } func (x *RemoveProfileRequest) GetUsername() string { @@ -4595,7 +4465,7 @@ type RemoveProfileResponse struct { func (x *RemoveProfileResponse) Reset() { *x = RemoveProfileResponse{} - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4607,7 +4477,7 @@ func (x *RemoveProfileResponse) String() string { func (*RemoveProfileResponse) ProtoMessage() {} func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[61] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4620,7 +4490,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{63} + return file_daemon_proto_rawDescGZIP(), []int{61} } type ListProfilesRequest struct { @@ -4632,7 +4502,7 @@ type ListProfilesRequest struct { func (x *ListProfilesRequest) Reset() { *x = ListProfilesRequest{} - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4644,7 +4514,7 @@ func (x *ListProfilesRequest) String() string { func (*ListProfilesRequest) ProtoMessage() {} func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[62] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4657,7 +4527,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. func (*ListProfilesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{64} + return file_daemon_proto_rawDescGZIP(), []int{62} } func (x *ListProfilesRequest) GetUsername() string { @@ -4676,7 +4546,7 @@ type ListProfilesResponse struct { func (x *ListProfilesResponse) Reset() { *x = ListProfilesResponse{} - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4688,7 +4558,7 @@ func (x *ListProfilesResponse) String() string { func (*ListProfilesResponse) ProtoMessage() {} func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[63] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4701,7 +4571,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. func (*ListProfilesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{65} + return file_daemon_proto_rawDescGZIP(), []int{63} } func (x *ListProfilesResponse) GetProfiles() []*Profile { @@ -4721,7 +4591,7 @@ type Profile struct { func (x *Profile) Reset() { *x = Profile{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4733,7 +4603,7 @@ func (x *Profile) String() string { func (*Profile) ProtoMessage() {} func (x *Profile) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[64] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4746,7 +4616,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message { // Deprecated: Use Profile.ProtoReflect.Descriptor instead. func (*Profile) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{66} + return file_daemon_proto_rawDescGZIP(), []int{64} } func (x *Profile) GetName() string { @@ -4771,7 +4641,7 @@ type GetActiveProfileRequest struct { func (x *GetActiveProfileRequest) Reset() { *x = GetActiveProfileRequest{} - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4783,7 +4653,7 @@ func (x *GetActiveProfileRequest) String() string { func (*GetActiveProfileRequest) ProtoMessage() {} func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[65] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4796,7 +4666,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{67} + return file_daemon_proto_rawDescGZIP(), []int{65} } type GetActiveProfileResponse struct { @@ -4809,7 +4679,7 @@ type GetActiveProfileResponse struct { func (x *GetActiveProfileResponse) Reset() { *x = GetActiveProfileResponse{} - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4821,7 +4691,7 @@ func (x *GetActiveProfileResponse) String() string { func (*GetActiveProfileResponse) ProtoMessage() {} func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4834,7 +4704,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{68} + return file_daemon_proto_rawDescGZIP(), []int{66} } func (x *GetActiveProfileResponse) GetProfileName() string { @@ -4861,7 +4731,7 @@ type LogoutRequest struct { func (x *LogoutRequest) Reset() { *x = LogoutRequest{} - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4873,7 +4743,7 @@ func (x *LogoutRequest) String() string { func (*LogoutRequest) ProtoMessage() {} func (x *LogoutRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[67] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4886,7 +4756,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. func (*LogoutRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{69} + return file_daemon_proto_rawDescGZIP(), []int{67} } func (x *LogoutRequest) GetProfileName() string { @@ -4911,7 +4781,7 @@ type LogoutResponse struct { func (x *LogoutResponse) Reset() { *x = LogoutResponse{} - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4923,7 +4793,7 @@ func (x *LogoutResponse) String() string { func (*LogoutResponse) ProtoMessage() {} func (x *LogoutResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4936,7 +4806,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. func (*LogoutResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{70} + return file_daemon_proto_rawDescGZIP(), []int{68} } type GetFeaturesRequest struct { @@ -4947,7 +4817,7 @@ type GetFeaturesRequest struct { func (x *GetFeaturesRequest) Reset() { *x = GetFeaturesRequest{} - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4959,7 +4829,7 @@ func (x *GetFeaturesRequest) String() string { func (*GetFeaturesRequest) ProtoMessage() {} func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[69] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4972,20 +4842,21 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead. func (*GetFeaturesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{71} + return file_daemon_proto_rawDescGZIP(), []int{69} } type GetFeaturesResponse struct { state protoimpl.MessageState `protogen:"open.v1"` DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"` DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"` + DisableNetworks bool `protobuf:"varint,3,opt,name=disable_networks,json=disableNetworks,proto3" json:"disable_networks,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *GetFeaturesResponse) Reset() { *x = GetFeaturesResponse{} - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4997,7 +4868,7 @@ func (x *GetFeaturesResponse) String() string { func (*GetFeaturesResponse) ProtoMessage() {} func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[70] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5010,7 +4881,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead. func (*GetFeaturesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{72} + return file_daemon_proto_rawDescGZIP(), []int{70} } func (x *GetFeaturesResponse) GetDisableProfiles() bool { @@ -5027,6 +4898,13 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +func (x *GetFeaturesResponse) GetDisableNetworks() bool { + if x != nil { + return x.DisableNetworks + } + return false +} + type TriggerUpdateRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -5035,7 +4913,7 @@ type TriggerUpdateRequest struct { func (x *TriggerUpdateRequest) Reset() { *x = TriggerUpdateRequest{} - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[71] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5047,7 +4925,7 @@ func (x *TriggerUpdateRequest) String() string { func (*TriggerUpdateRequest) ProtoMessage() {} func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[71] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5060,7 +4938,7 @@ func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TriggerUpdateRequest.ProtoReflect.Descriptor instead. func (*TriggerUpdateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{73} + return file_daemon_proto_rawDescGZIP(), []int{71} } type TriggerUpdateResponse struct { @@ -5073,7 +4951,7 @@ type TriggerUpdateResponse struct { func (x *TriggerUpdateResponse) Reset() { *x = TriggerUpdateResponse{} - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[72] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5085,7 +4963,7 @@ func (x *TriggerUpdateResponse) String() string { func (*TriggerUpdateResponse) ProtoMessage() {} func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[72] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5098,7 +4976,7 @@ func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TriggerUpdateResponse.ProtoReflect.Descriptor instead. func (*TriggerUpdateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{74} + return file_daemon_proto_rawDescGZIP(), []int{72} } func (x *TriggerUpdateResponse) GetSuccess() bool { @@ -5126,7 +5004,7 @@ type GetPeerSSHHostKeyRequest struct { func (x *GetPeerSSHHostKeyRequest) Reset() { *x = GetPeerSSHHostKeyRequest{} - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[73] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5138,7 +5016,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string { func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[73] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5151,7 +5029,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{75} + return file_daemon_proto_rawDescGZIP(), []int{73} } func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { @@ -5178,7 +5056,7 @@ type GetPeerSSHHostKeyResponse struct { func (x *GetPeerSSHHostKeyResponse) Reset() { *x = GetPeerSSHHostKeyResponse{} - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[74] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5190,7 +5068,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string { func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[74] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5203,7 +5081,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{76} + return file_daemon_proto_rawDescGZIP(), []int{74} } func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { @@ -5245,7 +5123,7 @@ type RequestJWTAuthRequest struct { func (x *RequestJWTAuthRequest) Reset() { *x = RequestJWTAuthRequest{} - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5257,7 +5135,7 @@ func (x *RequestJWTAuthRequest) String() string { func (*RequestJWTAuthRequest) ProtoMessage() {} func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[75] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5270,7 +5148,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{77} + return file_daemon_proto_rawDescGZIP(), []int{75} } func (x *RequestJWTAuthRequest) GetHint() string { @@ -5303,7 +5181,7 @@ type RequestJWTAuthResponse struct { func (x *RequestJWTAuthResponse) Reset() { *x = RequestJWTAuthResponse{} - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5315,7 +5193,7 @@ func (x *RequestJWTAuthResponse) String() string { func (*RequestJWTAuthResponse) ProtoMessage() {} func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[76] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5328,7 +5206,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{78} + return file_daemon_proto_rawDescGZIP(), []int{76} } func (x *RequestJWTAuthResponse) GetVerificationURI() string { @@ -5393,7 +5271,7 @@ type WaitJWTTokenRequest struct { func (x *WaitJWTTokenRequest) Reset() { *x = WaitJWTTokenRequest{} - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5405,7 +5283,7 @@ func (x *WaitJWTTokenRequest) String() string { func (*WaitJWTTokenRequest) ProtoMessage() {} func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[77] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5418,7 +5296,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{79} + return file_daemon_proto_rawDescGZIP(), []int{77} } func (x *WaitJWTTokenRequest) GetDeviceCode() string { @@ -5450,7 +5328,7 @@ type WaitJWTTokenResponse struct { func (x *WaitJWTTokenResponse) Reset() { *x = WaitJWTTokenResponse{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5462,7 +5340,7 @@ func (x *WaitJWTTokenResponse) String() string { func (*WaitJWTTokenResponse) ProtoMessage() {} func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5475,7 +5353,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{80} + return file_daemon_proto_rawDescGZIP(), []int{78} } func (x *WaitJWTTokenResponse) GetToken() string { @@ -5508,7 +5386,7 @@ type StartCPUProfileRequest struct { func (x *StartCPUProfileRequest) Reset() { *x = StartCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[79] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5520,7 +5398,7 @@ func (x *StartCPUProfileRequest) String() string { func (*StartCPUProfileRequest) ProtoMessage() {} func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[79] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5533,7 +5411,7 @@ func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead. func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{81} + return file_daemon_proto_rawDescGZIP(), []int{79} } // StartCPUProfileResponse confirms CPU profiling has started @@ -5545,7 +5423,7 @@ type StartCPUProfileResponse struct { func (x *StartCPUProfileResponse) Reset() { *x = StartCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5557,7 +5435,7 @@ func (x *StartCPUProfileResponse) String() string { func (*StartCPUProfileResponse) ProtoMessage() {} func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[80] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5570,7 +5448,7 @@ func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead. func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{82} + return file_daemon_proto_rawDescGZIP(), []int{80} } // StopCPUProfileRequest for stopping CPU profiling @@ -5582,7 +5460,7 @@ type StopCPUProfileRequest struct { func (x *StopCPUProfileRequest) Reset() { *x = StopCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[81] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5594,7 +5472,7 @@ func (x *StopCPUProfileRequest) String() string { func (*StopCPUProfileRequest) ProtoMessage() {} func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[81] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5607,7 +5485,7 @@ func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead. func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{83} + return file_daemon_proto_rawDescGZIP(), []int{81} } // StopCPUProfileResponse confirms CPU profiling has stopped @@ -5619,7 +5497,7 @@ type StopCPUProfileResponse struct { func (x *StopCPUProfileResponse) Reset() { *x = StopCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5631,7 +5509,7 @@ func (x *StopCPUProfileResponse) String() string { func (*StopCPUProfileResponse) ProtoMessage() {} func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5644,7 +5522,7 @@ func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead. func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{84} + return file_daemon_proto_rawDescGZIP(), []int{82} } type InstallerResultRequest struct { @@ -5655,7 +5533,7 @@ type InstallerResultRequest struct { func (x *InstallerResultRequest) Reset() { *x = InstallerResultRequest{} - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5667,7 +5545,7 @@ func (x *InstallerResultRequest) String() string { func (*InstallerResultRequest) ProtoMessage() {} func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[83] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5680,7 +5558,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. func (*InstallerResultRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{85} + return file_daemon_proto_rawDescGZIP(), []int{83} } type InstallerResultResponse struct { @@ -5693,7 +5571,7 @@ type InstallerResultResponse struct { func (x *InstallerResultResponse) Reset() { *x = InstallerResultResponse{} - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5705,7 +5583,7 @@ func (x *InstallerResultResponse) String() string { func (*InstallerResultResponse) ProtoMessage() {} func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[84] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5718,7 +5596,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. func (*InstallerResultResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{86} + return file_daemon_proto_rawDescGZIP(), []int{84} } func (x *InstallerResultResponse) GetSuccess() bool { @@ -5751,7 +5629,7 @@ type ExposeServiceRequest struct { func (x *ExposeServiceRequest) Reset() { *x = ExposeServiceRequest{} - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[85] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5763,7 +5641,7 @@ func (x *ExposeServiceRequest) String() string { func (*ExposeServiceRequest) ProtoMessage() {} func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[85] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5776,7 +5654,7 @@ func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead. func (*ExposeServiceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{87} + return file_daemon_proto_rawDescGZIP(), []int{85} } func (x *ExposeServiceRequest) GetPort() uint32 { @@ -5847,7 +5725,7 @@ type ExposeServiceEvent struct { func (x *ExposeServiceEvent) Reset() { *x = ExposeServiceEvent{} - mi := &file_daemon_proto_msgTypes[88] + mi := &file_daemon_proto_msgTypes[86] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5859,7 +5737,7 @@ func (x *ExposeServiceEvent) String() string { func (*ExposeServiceEvent) ProtoMessage() {} func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[88] + mi := &file_daemon_proto_msgTypes[86] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5872,7 +5750,7 @@ func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead. func (*ExposeServiceEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{88} + return file_daemon_proto_rawDescGZIP(), []int{86} } func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event { @@ -5913,7 +5791,7 @@ type ExposeServiceReady struct { func (x *ExposeServiceReady) Reset() { *x = ExposeServiceReady{} - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[87] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5925,7 +5803,7 @@ func (x *ExposeServiceReady) String() string { func (*ExposeServiceReady) ProtoMessage() {} func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[87] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5938,7 +5816,7 @@ func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead. func (*ExposeServiceReady) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{89} + return file_daemon_proto_rawDescGZIP(), []int{87} } func (x *ExposeServiceReady) GetServiceName() string { @@ -5969,6 +5847,288 @@ func (x *ExposeServiceReady) GetPortAutoAssigned() bool { return false } +type StartCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + TextOutput bool `protobuf:"varint,1,opt,name=text_output,json=textOutput,proto3" json:"text_output,omitempty"` + SnapLen uint32 `protobuf:"varint,2,opt,name=snap_len,json=snapLen,proto3" json:"snap_len,omitempty"` + Duration *durationpb.Duration `protobuf:"bytes,3,opt,name=duration,proto3" json:"duration,omitempty"` + FilterExpr string `protobuf:"bytes,4,opt,name=filter_expr,json=filterExpr,proto3" json:"filter_expr,omitempty"` + Verbose bool `protobuf:"varint,5,opt,name=verbose,proto3" json:"verbose,omitempty"` + Ascii bool `protobuf:"varint,6,opt,name=ascii,proto3" json:"ascii,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCaptureRequest) Reset() { + *x = StartCaptureRequest{} + mi := &file_daemon_proto_msgTypes[88] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCaptureRequest) ProtoMessage() {} + +func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[88] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead. +func (*StartCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{88} +} + +func (x *StartCaptureRequest) GetTextOutput() bool { + if x != nil { + return x.TextOutput + } + return false +} + +func (x *StartCaptureRequest) GetSnapLen() uint32 { + if x != nil { + return x.SnapLen + } + return 0 +} + +func (x *StartCaptureRequest) GetDuration() *durationpb.Duration { + if x != nil { + return x.Duration + } + return nil +} + +func (x *StartCaptureRequest) GetFilterExpr() string { + if x != nil { + return x.FilterExpr + } + return "" +} + +func (x *StartCaptureRequest) GetVerbose() bool { + if x != nil { + return x.Verbose + } + return false +} + +func (x *StartCaptureRequest) GetAscii() bool { + if x != nil { + return x.Ascii + } + return false +} + +type CapturePacket struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CapturePacket) Reset() { + *x = CapturePacket{} + mi := &file_daemon_proto_msgTypes[89] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CapturePacket) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CapturePacket) ProtoMessage() {} + +func (x *CapturePacket) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[89] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead. +func (*CapturePacket) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{89} +} + +func (x *CapturePacket) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type StartBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. + Timeout *durationpb.Duration `protobuf:"bytes,1,opt,name=timeout,proto3" json:"timeout,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureRequest) Reset() { + *x = StartBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[90] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureRequest) ProtoMessage() {} + +func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[90] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead. +func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{90} +} + +func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration { + if x != nil { + return x.Timeout + } + return nil +} + +type StartBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureResponse) Reset() { + *x = StartBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[91] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureResponse) ProtoMessage() {} + +func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[91] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead. +func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{91} +} + +type StopBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureRequest) Reset() { + *x = StopBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[92] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureRequest) ProtoMessage() {} + +func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[92] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead. +func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{92} +} + +type StopBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureResponse) Reset() { + *x = StopBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[93] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureResponse) ProtoMessage() {} + +func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[93] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead. +func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{93} +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5979,7 +6139,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[91] + mi := &file_daemon_proto_msgTypes[95] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5991,7 +6151,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[91] + mi := &file_daemon_proto_msgTypes[95] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6004,7 +6164,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30, 0} + return file_daemon_proto_rawDescGZIP(), []int{28, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -6026,15 +6186,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\x7f\n" + - "\x12OSLifecycleRequest\x128\n" + - "\x04type\x18\x01 \x01(\x0e2$.daemon.OSLifecycleRequest.CycleTypeR\x04type\"/\n" + - "\tCycleType\x12\v\n" + - "\aUNKNOWN\x10\x00\x12\t\n" + - "\x05SLEEP\x10\x01\x12\n" + - "\n" + - "\x06WAKEUP\x10\x02\"\x15\n" + - "\x13OSLifecycleResponse\"\xb6\x12\n" + + "\fEmptyRequest\"\xb6\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -6472,10 +6624,11 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\x10\n" + "\x0eLogoutResponse\"\x14\n" + - "\x12GetFeaturesRequest\"x\n" + + "\x12GetFeaturesRequest\"\xa3\x01\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" + + "\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" + "\x14TriggerUpdateRequest\"M\n" + "\x15TriggerUpdateResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + @@ -6539,7 +6692,23 @@ const file_daemon_proto_rawDesc = "" + "\vservice_url\x18\x02 \x01(\tR\n" + "serviceUrl\x12\x16\n" + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + - "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned\"\xd9\x01\n" + + "\x13StartCaptureRequest\x12\x1f\n" + + "\vtext_output\x18\x01 \x01(\bR\n" + + "textOutput\x12\x19\n" + + "\bsnap_len\x18\x02 \x01(\rR\asnapLen\x125\n" + + "\bduration\x18\x03 \x01(\v2\x19.google.protobuf.DurationR\bduration\x12\x1f\n" + + "\vfilter_expr\x18\x04 \x01(\tR\n" + + "filterExpr\x12\x18\n" + + "\averbose\x18\x05 \x01(\bR\averbose\x12\x14\n" + + "\x05ascii\x18\x06 \x01(\bR\x05ascii\"#\n" + + "\rCapturePacket\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\"P\n" + + "\x19StartBundleCaptureRequest\x123\n" + + "\atimeout\x18\x01 \x01(\v2\x19.google.protobuf.DurationR\atimeout\"\x1c\n" + + "\x1aStartBundleCaptureResponse\"\x1a\n" + + "\x18StopBundleCaptureRequest\"\x1b\n" + + "\x19StopBundleCaptureResponse*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6557,7 +6726,7 @@ const file_daemon_proto_rawDesc = "" + "\n" + "EXPOSE_UDP\x10\x03\x12\x0e\n" + "\n" + - "EXPOSE_TLS\x10\x042\xfc\x15\n" + + "EXPOSE_TLS\x10\x042\xaf\x17\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6578,7 +6747,10 @@ const file_daemon_proto_rawDesc = "" + "CleanState\x12\x19.daemon.CleanStateRequest\x1a\x1a.daemon.CleanStateResponse\"\x00\x12H\n" + "\vDeleteState\x12\x1a.daemon.DeleteStateRequest\x1a\x1b.daemon.DeleteStateResponse\"\x00\x12u\n" + "\x1aSetSyncResponsePersistence\x12).daemon.SetSyncResponsePersistenceRequest\x1a*.daemon.SetSyncResponsePersistenceResponse\"\x00\x12H\n" + - "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" + + "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12F\n" + + "\fStartCapture\x12\x1b.daemon.StartCaptureRequest\x1a\x15.daemon.CapturePacket\"\x000\x01\x12]\n" + + "\x12StartBundleCapture\x12!.daemon.StartBundleCaptureRequest\x1a\".daemon.StartBundleCaptureResponse\"\x00\x12Z\n" + + "\x11StopBundleCapture\x12 .daemon.StopBundleCaptureRequest\x1a!.daemon.StopBundleCaptureResponse\"\x00\x12D\n" + "\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" + "\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12N\n" + "\rSwitchProfile\x12\x1c.daemon.SwitchProfileRequest\x1a\x1d.daemon.SwitchProfileResponse\"\x00\x12B\n" + @@ -6595,8 +6767,7 @@ const file_daemon_proto_rawDesc = "" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" + "\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" + - "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" + - "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + + "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12W\n" + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" + "\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3" @@ -6612,226 +6783,234 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 93) +var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 97) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (ExposeProtocol)(0), // 1: daemon.ExposeProtocol - (OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType - (SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity - (SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category - (*EmptyRequest)(nil), // 5: daemon.EmptyRequest - (*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest - (*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse - (*LoginRequest)(nil), // 8: daemon.LoginRequest - (*LoginResponse)(nil), // 9: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 12: daemon.UpRequest - (*UpResponse)(nil), // 13: daemon.UpResponse - (*StatusRequest)(nil), // 14: daemon.StatusRequest - (*StatusResponse)(nil), // 15: daemon.StatusResponse - (*DownRequest)(nil), // 16: daemon.DownRequest - (*DownResponse)(nil), // 17: daemon.DownResponse - (*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse - (*PeerState)(nil), // 20: daemon.PeerState - (*LocalPeerState)(nil), // 21: daemon.LocalPeerState - (*SignalState)(nil), // 22: daemon.SignalState - (*ManagementState)(nil), // 23: daemon.ManagementState - (*RelayState)(nil), // 24: daemon.RelayState - (*NSGroupState)(nil), // 25: daemon.NSGroupState - (*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo - (*SSHServerState)(nil), // 27: daemon.SSHServerState - (*FullStatus)(nil), // 28: daemon.FullStatus - (*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse - (*IPList)(nil), // 33: daemon.IPList - (*Network)(nil), // 34: daemon.Network - (*PortInfo)(nil), // 35: daemon.PortInfo - (*ForwardingRule)(nil), // 36: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse - (*State)(nil), // 44: daemon.State - (*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 53: daemon.TCPFlags - (*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest - (*TraceStage)(nil), // 55: daemon.TraceStage - (*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest - (*SystemEvent)(nil), // 58: daemon.SystemEvent - (*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse - (*Profile)(nil), // 71: daemon.Profile - (*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 74: daemon.LogoutRequest - (*LogoutResponse)(nil), // 75: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse - (*TriggerUpdateRequest)(nil), // 78: daemon.TriggerUpdateRequest - (*TriggerUpdateResponse)(nil), // 79: daemon.TriggerUpdateResponse - (*GetPeerSSHHostKeyRequest)(nil), // 80: daemon.GetPeerSSHHostKeyRequest - (*GetPeerSSHHostKeyResponse)(nil), // 81: daemon.GetPeerSSHHostKeyResponse - (*RequestJWTAuthRequest)(nil), // 82: daemon.RequestJWTAuthRequest - (*RequestJWTAuthResponse)(nil), // 83: daemon.RequestJWTAuthResponse - (*WaitJWTTokenRequest)(nil), // 84: daemon.WaitJWTTokenRequest - (*WaitJWTTokenResponse)(nil), // 85: daemon.WaitJWTTokenResponse - (*StartCPUProfileRequest)(nil), // 86: daemon.StartCPUProfileRequest - (*StartCPUProfileResponse)(nil), // 87: daemon.StartCPUProfileResponse - (*StopCPUProfileRequest)(nil), // 88: daemon.StopCPUProfileRequest - (*StopCPUProfileResponse)(nil), // 89: daemon.StopCPUProfileResponse - (*InstallerResultRequest)(nil), // 90: daemon.InstallerResultRequest - (*InstallerResultResponse)(nil), // 91: daemon.InstallerResultResponse - (*ExposeServiceRequest)(nil), // 92: daemon.ExposeServiceRequest - (*ExposeServiceEvent)(nil), // 93: daemon.ExposeServiceEvent - (*ExposeServiceReady)(nil), // 94: daemon.ExposeServiceReady - nil, // 95: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 96: daemon.PortInfo.Range - nil, // 97: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 98: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 99: google.protobuf.Timestamp + (SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity + (SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category + (*EmptyRequest)(nil), // 4: daemon.EmptyRequest + (*LoginRequest)(nil), // 5: daemon.LoginRequest + (*LoginResponse)(nil), // 6: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 7: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 8: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 9: daemon.UpRequest + (*UpResponse)(nil), // 10: daemon.UpResponse + (*StatusRequest)(nil), // 11: daemon.StatusRequest + (*StatusResponse)(nil), // 12: daemon.StatusResponse + (*DownRequest)(nil), // 13: daemon.DownRequest + (*DownResponse)(nil), // 14: daemon.DownResponse + (*GetConfigRequest)(nil), // 15: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 16: daemon.GetConfigResponse + (*PeerState)(nil), // 17: daemon.PeerState + (*LocalPeerState)(nil), // 18: daemon.LocalPeerState + (*SignalState)(nil), // 19: daemon.SignalState + (*ManagementState)(nil), // 20: daemon.ManagementState + (*RelayState)(nil), // 21: daemon.RelayState + (*NSGroupState)(nil), // 22: daemon.NSGroupState + (*SSHSessionInfo)(nil), // 23: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 24: daemon.SSHServerState + (*FullStatus)(nil), // 25: daemon.FullStatus + (*ListNetworksRequest)(nil), // 26: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 27: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 28: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 29: daemon.SelectNetworksResponse + (*IPList)(nil), // 30: daemon.IPList + (*Network)(nil), // 31: daemon.Network + (*PortInfo)(nil), // 32: daemon.PortInfo + (*ForwardingRule)(nil), // 33: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 34: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 35: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 36: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 37: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 38: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 39: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 40: daemon.SetLogLevelResponse + (*State)(nil), // 41: daemon.State + (*ListStatesRequest)(nil), // 42: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 43: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 44: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 45: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 46: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 47: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 48: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 49: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 50: daemon.TCPFlags + (*TracePacketRequest)(nil), // 51: daemon.TracePacketRequest + (*TraceStage)(nil), // 52: daemon.TraceStage + (*TracePacketResponse)(nil), // 53: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 54: daemon.SubscribeRequest + (*SystemEvent)(nil), // 55: daemon.SystemEvent + (*GetEventsRequest)(nil), // 56: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 57: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 58: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 59: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 60: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 61: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 62: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 63: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 64: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 65: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 66: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 67: daemon.ListProfilesResponse + (*Profile)(nil), // 68: daemon.Profile + (*GetActiveProfileRequest)(nil), // 69: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 70: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), // 71: daemon.LogoutRequest + (*LogoutResponse)(nil), // 72: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 73: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 74: daemon.GetFeaturesResponse + (*TriggerUpdateRequest)(nil), // 75: daemon.TriggerUpdateRequest + (*TriggerUpdateResponse)(nil), // 76: daemon.TriggerUpdateResponse + (*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse + (*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest + (*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse + (*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest + (*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse + (*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse + (*ExposeServiceRequest)(nil), // 89: daemon.ExposeServiceRequest + (*ExposeServiceEvent)(nil), // 90: daemon.ExposeServiceEvent + (*ExposeServiceReady)(nil), // 91: daemon.ExposeServiceReady + (*StartCaptureRequest)(nil), // 92: daemon.StartCaptureRequest + (*CapturePacket)(nil), // 93: daemon.CapturePacket + (*StartBundleCaptureRequest)(nil), // 94: daemon.StartBundleCaptureRequest + (*StartBundleCaptureResponse)(nil), // 95: daemon.StartBundleCaptureResponse + (*StopBundleCaptureRequest)(nil), // 96: daemon.StopBundleCaptureRequest + (*StopBundleCaptureResponse)(nil), // 97: daemon.StopBundleCaptureResponse + nil, // 98: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 99: daemon.PortInfo.Range + nil, // 100: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 101: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 102: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 98, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 99, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 99, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 98, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo - 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState - 24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState - 25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState - 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 95, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 96, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State - 53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 99, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 97, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 98, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol - 94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady - 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 78, // 64: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest - 80, // 65: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 82, // 66: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 84, // 67: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 86, // 68: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest - 88, // 69: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest - 6, // 70: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 90, // 71: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 92, // 72: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest - 9, // 73: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 11, // 74: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 13, // 75: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 15, // 76: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 17, // 77: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 19, // 78: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 30, // 79: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 32, // 80: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 32, // 81: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 37, // 82: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 39, // 83: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 41, // 84: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 43, // 85: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 46, // 86: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 48, // 87: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 50, // 88: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 52, // 89: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 56, // 90: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 58, // 91: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 60, // 92: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 62, // 93: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 64, // 94: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 66, // 95: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 68, // 96: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 70, // 97: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 73, // 98: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 75, // 99: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 77, // 100: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 79, // 101: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse - 81, // 102: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 83, // 103: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 85, // 104: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 87, // 105: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse - 89, // 106: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse - 7, // 107: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 91, // 108: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 93, // 109: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent - 73, // [73:110] is the sub-list for method output_type - 36, // [36:73] is the sub-list for method input_type - 36, // [36:36] is the sub-list for extension type_name - 36, // [36:36] is the sub-list for extension extendee - 0, // [0:36] is the sub-list for field type_name + 101, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 25, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 102, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 102, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 101, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 19, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 18, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 17, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState + 21, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState + 22, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 98, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 99, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 41, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State + 50, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 3, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 102, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 100, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 101, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol + 91, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady + 101, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration + 101, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration + 30, // 36: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 5, // 37: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 7, // 38: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 9, // 39: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 11, // 40: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 13, // 41: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 15, // 42: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 26, // 43: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 28, // 44: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 28, // 45: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 4, // 46: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 35, // 47: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 37, // 48: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 39, // 49: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 42, // 50: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 44, // 51: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 46, // 52: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 48, // 53: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 51, // 54: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 92, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest + 94, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest + 96, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest + 54, // 58: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 56, // 59: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 58, // 60: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 60, // 61: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 62, // 62: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 64, // 63: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 66, // 64: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 69, // 65: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 71, // 66: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 73, // 67: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 75, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest + 77, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 79, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 81, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 83, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 85, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 87, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 89, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest + 6, // 76: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 8, // 77: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 10, // 78: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 12, // 79: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 14, // 80: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 16, // 81: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 27, // 82: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 29, // 83: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 29, // 84: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 34, // 85: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 36, // 86: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 38, // 87: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 40, // 88: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 43, // 89: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 45, // 90: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 47, // 91: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 49, // 92: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 53, // 93: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 93, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket + 95, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse + 97, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse + 55, // 97: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 57, // 98: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 59, // 99: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 61, // 100: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 63, // 101: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 65, // 102: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 67, // 103: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 70, // 104: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 72, // 105: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 74, // 106: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 76, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse + 78, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 84, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 86, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 88, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 90, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent + 76, // [76:115] is the sub-list for method output_type + 37, // [37:76] is the sub-list for method input_type + 37, // [37:37] is the sub-list for extension type_name + 37, // [37:37] is the sub-list for extension extendee + 0, // [0:37] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -6839,20 +7018,20 @@ func file_daemon_proto_init() { if File_daemon_proto != nil { return } - file_daemon_proto_msgTypes[3].OneofWrappers = []any{} + file_daemon_proto_msgTypes[1].OneofWrappers = []any{} + file_daemon_proto_msgTypes[5].OneofWrappers = []any{} file_daemon_proto_msgTypes[7].OneofWrappers = []any{} - file_daemon_proto_msgTypes[9].OneofWrappers = []any{} - file_daemon_proto_msgTypes[30].OneofWrappers = []any{ + file_daemon_proto_msgTypes[28].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } - file_daemon_proto_msgTypes[49].OneofWrappers = []any{} - file_daemon_proto_msgTypes[50].OneofWrappers = []any{} + file_daemon_proto_msgTypes[47].OneofWrappers = []any{} + file_daemon_proto_msgTypes[48].OneofWrappers = []any{} + file_daemon_proto_msgTypes[54].OneofWrappers = []any{} file_daemon_proto_msgTypes[56].OneofWrappers = []any{} - file_daemon_proto_msgTypes[58].OneofWrappers = []any{} - file_daemon_proto_msgTypes[69].OneofWrappers = []any{} - file_daemon_proto_msgTypes[77].OneofWrappers = []any{} - file_daemon_proto_msgTypes[88].OneofWrappers = []any{ + file_daemon_proto_msgTypes[67].OneofWrappers = []any{} + file_daemon_proto_msgTypes[75].OneofWrappers = []any{} + file_daemon_proto_msgTypes[86].OneofWrappers = []any{ (*ExposeServiceEvent_Ready)(nil), } type x struct{} @@ -6860,8 +7039,8 @@ func file_daemon_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), - NumEnums: 5, - NumMessages: 93, + NumEnums: 4, + NumMessages: 97, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 89302c8c3..3fee9eca8 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -64,6 +64,17 @@ service DaemonService { rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {} + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {} + + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {} + + // StopBundleCapture stops the running bundle capture. Idempotent. + rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {} + rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {} rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {} @@ -104,8 +115,6 @@ service DaemonService { // StopCPUProfile stops CPU profiling in the daemon rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {} - rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} - rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} // ExposeService exposes a local port via the NetBird reverse proxy @@ -114,20 +123,6 @@ service DaemonService { -message OSLifecycleRequest { - // avoid collision with loglevel enum - enum CycleType { - UNKNOWN = 0; - SLEEP = 1; - WAKEUP = 2; - } - - CycleType type = 1; -} - -message OSLifecycleResponse {} - - message LoginRequest { // setupKey netbird setup key. string setupKey = 1; @@ -727,6 +722,7 @@ message GetFeaturesRequest{} message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; + bool disable_networks = 3; } message TriggerUpdateRequest {} @@ -847,3 +843,26 @@ message ExposeServiceReady { string domain = 3; bool port_auto_assigned = 4; } + +message StartCaptureRequest { + bool text_output = 1; + uint32 snap_len = 2; + google.protobuf.Duration duration = 3; + string filter_expr = 4; + bool verbose = 5; + bool ascii = 6; +} + +message CapturePacket { + bytes data = 1; +} + +message StartBundleCaptureRequest { + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. + google.protobuf.Duration timeout = 1; +} + +message StartBundleCaptureResponse {} +message StopBundleCaptureRequest {} +message StopBundleCaptureResponse {} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index e5bd89597..66a8efcc3 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc v6.33.1 +// source: daemon.proto package proto @@ -11,8 +15,50 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + DaemonService_Login_FullMethodName = "/daemon.DaemonService/Login" + DaemonService_WaitSSOLogin_FullMethodName = "/daemon.DaemonService/WaitSSOLogin" + DaemonService_Up_FullMethodName = "/daemon.DaemonService/Up" + DaemonService_Status_FullMethodName = "/daemon.DaemonService/Status" + DaemonService_Down_FullMethodName = "/daemon.DaemonService/Down" + DaemonService_GetConfig_FullMethodName = "/daemon.DaemonService/GetConfig" + DaemonService_ListNetworks_FullMethodName = "/daemon.DaemonService/ListNetworks" + DaemonService_SelectNetworks_FullMethodName = "/daemon.DaemonService/SelectNetworks" + DaemonService_DeselectNetworks_FullMethodName = "/daemon.DaemonService/DeselectNetworks" + DaemonService_ForwardingRules_FullMethodName = "/daemon.DaemonService/ForwardingRules" + DaemonService_DebugBundle_FullMethodName = "/daemon.DaemonService/DebugBundle" + DaemonService_GetLogLevel_FullMethodName = "/daemon.DaemonService/GetLogLevel" + DaemonService_SetLogLevel_FullMethodName = "/daemon.DaemonService/SetLogLevel" + DaemonService_ListStates_FullMethodName = "/daemon.DaemonService/ListStates" + DaemonService_CleanState_FullMethodName = "/daemon.DaemonService/CleanState" + DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState" + DaemonService_SetSyncResponsePersistence_FullMethodName = "/daemon.DaemonService/SetSyncResponsePersistence" + DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket" + DaemonService_StartCapture_FullMethodName = "/daemon.DaemonService/StartCapture" + DaemonService_StartBundleCapture_FullMethodName = "/daemon.DaemonService/StartBundleCapture" + DaemonService_StopBundleCapture_FullMethodName = "/daemon.DaemonService/StopBundleCapture" + DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents" + DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents" + DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile" + DaemonService_SetConfig_FullMethodName = "/daemon.DaemonService/SetConfig" + DaemonService_AddProfile_FullMethodName = "/daemon.DaemonService/AddProfile" + DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile" + DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles" + DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile" + DaemonService_Logout_FullMethodName = "/daemon.DaemonService/Logout" + DaemonService_GetFeatures_FullMethodName = "/daemon.DaemonService/GetFeatures" + DaemonService_TriggerUpdate_FullMethodName = "/daemon.DaemonService/TriggerUpdate" + DaemonService_GetPeerSSHHostKey_FullMethodName = "/daemon.DaemonService/GetPeerSSHHostKey" + DaemonService_RequestJWTAuth_FullMethodName = "/daemon.DaemonService/RequestJWTAuth" + DaemonService_WaitJWTToken_FullMethodName = "/daemon.DaemonService/WaitJWTToken" + DaemonService_StartCPUProfile_FullMethodName = "/daemon.DaemonService/StartCPUProfile" + DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile" + DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult" + DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService" +) // DaemonServiceClient is the client API for DaemonService service. // @@ -53,7 +99,15 @@ type DaemonServiceClient interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) - SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error) + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. + StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) + SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) @@ -77,10 +131,9 @@ type DaemonServiceClient interface { StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) // StopCPUProfile stops CPU profiling in the daemon StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) - NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) // ExposeService exposes a local port via the NetBird reverse proxy - ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) + ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error) } type daemonServiceClient struct { @@ -92,8 +145,9 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient { } func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(LoginResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Login", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Login_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -101,8 +155,9 @@ func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts } func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(WaitSSOLoginResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitSSOLogin", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_WaitSSOLogin_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -110,8 +165,9 @@ func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLogin } func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(UpResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Up", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Up_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -119,8 +175,9 @@ func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grp } func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StatusResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Status", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Status_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -128,8 +185,9 @@ func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opt } func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DownResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Down", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Down_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -137,8 +195,9 @@ func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts .. } func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetConfigResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetConfig", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetConfig_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -146,8 +205,9 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques } func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListNetworksResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListNetworks_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -155,8 +215,9 @@ func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworks } func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SelectNetworks_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -164,8 +225,9 @@ func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetw } func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_DeselectNetworks_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -173,8 +235,9 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe } func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ForwardingRulesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ForwardingRules_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -182,8 +245,9 @@ func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequ } func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DebugBundleResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_DebugBundle_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -191,8 +255,9 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe } func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetLogLevelResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetLogLevel_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -200,8 +265,9 @@ func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRe } func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetLogLevelResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SetLogLevel_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -209,8 +275,9 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe } func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListStatesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListStates_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -218,8 +285,9 @@ func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequ } func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(CleanStateResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_CleanState_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -227,8 +295,9 @@ func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequ } func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DeleteStateResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_DeleteState_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -236,8 +305,9 @@ func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRe } func (c *daemonServiceClient) SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetSyncResponsePersistenceResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetSyncResponsePersistence", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SetSyncResponsePersistence_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -245,20 +315,22 @@ func (c *daemonServiceClient) SetSyncResponsePersistence(ctx context.Context, in } func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(TracePacketResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_TracePacket_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) { - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...) +func (c *daemonServiceClient) StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_StartCapture_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &daemonServiceSubscribeEventsClient{stream} + x := &grpc.GenericClientStream[StartCaptureRequest, CapturePacket]{ClientStream: stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -268,26 +340,52 @@ func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *Subscribe return x, nil } -type DaemonService_SubscribeEventsClient interface { - Recv() (*SystemEvent, error) - grpc.ClientStream -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_StartCaptureClient = grpc.ServerStreamingClient[CapturePacket] -type daemonServiceSubscribeEventsClient struct { - grpc.ClientStream -} - -func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) { - m := new(SystemEvent) - if err := x.ClientStream.RecvMsg(m); err != nil { +func (c *daemonServiceClient) StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(StartBundleCaptureResponse) + err := c.cc.Invoke(ctx, DaemonService_StartBundleCapture_FullMethodName, in, out, cOpts...) + if err != nil { return nil, err } - return m, nil + return out, nil } +func (c *daemonServiceClient) StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(StopBundleCaptureResponse) + err := c.cc.Invoke(ctx, DaemonService_StopBundleCapture_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], DaemonService_SubscribeEvents_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[SubscribeRequest, SystemEvent]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_SubscribeEventsClient = grpc.ServerStreamingClient[SystemEvent] + func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetEventsResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetEvents_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -295,8 +393,9 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques } func (c *daemonServiceClient) SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SwitchProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SwitchProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SwitchProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -304,8 +403,9 @@ func (c *daemonServiceClient) SwitchProfile(ctx context.Context, in *SwitchProfi } func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetConfigResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetConfig", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SetConfig_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -313,8 +413,9 @@ func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigReques } func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(AddProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/AddProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_AddProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -322,8 +423,9 @@ func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequ } func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(RemoveProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/RemoveProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_RemoveProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -331,8 +433,9 @@ func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfi } func (c *daemonServiceClient) ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListProfilesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListProfiles", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListProfiles_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -340,8 +443,9 @@ func (c *daemonServiceClient) ListProfiles(ctx context.Context, in *ListProfiles } func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetActiveProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetActiveProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetActiveProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -349,8 +453,9 @@ func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiv } func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(LogoutResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Logout", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Logout_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -358,8 +463,9 @@ func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opt } func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetFeaturesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetFeatures", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetFeatures_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -367,8 +473,9 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe } func (c *daemonServiceClient) TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(TriggerUpdateResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/TriggerUpdate", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_TriggerUpdate_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -376,8 +483,9 @@ func (c *daemonServiceClient) TriggerUpdate(ctx context.Context, in *TriggerUpda } func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetPeerSSHHostKeyResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetPeerSSHHostKey_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -385,8 +493,9 @@ func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeer } func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(RequestJWTAuthResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_RequestJWTAuth_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -394,8 +503,9 @@ func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWT } func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(WaitJWTTokenResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_WaitJWTToken_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -403,8 +513,9 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken } func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StartCPUProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_StartCPUProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -412,17 +523,9 @@ func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUP } func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StopCPUProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) { - out := new(OSLifecycleResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_StopCPUProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -430,20 +533,22 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec } func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(InstallerResultResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetInstallerResult_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) { - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...) +func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[2], DaemonService_ExposeService_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &daemonServiceExposeServiceClient{stream} + x := &grpc.GenericClientStream[ExposeServiceRequest, ExposeServiceEvent]{ClientStream: stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -453,26 +558,12 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi return x, nil } -type DaemonService_ExposeServiceClient interface { - Recv() (*ExposeServiceEvent, error) - grpc.ClientStream -} - -type daemonServiceExposeServiceClient struct { - grpc.ClientStream -} - -func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) { - m := new(ExposeServiceEvent) - if err := x.ClientStream.RecvMsg(m); err != nil { - return nil, err - } - return m, nil -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent] // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer -// for forward compatibility +// for forward compatibility. type DaemonServiceServer interface { // Login uses setup key to prepare configuration for the daemon. Login(context.Context, *LoginRequest) (*LoginResponse, error) @@ -509,7 +600,15 @@ type DaemonServiceServer interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) - SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. + StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) + SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) @@ -533,129 +632,138 @@ type DaemonServiceServer interface { StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) // StopCPUProfile stops CPU profiling in the daemon StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) - NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) // ExposeService exposes a local port via the NetBird reverse proxy - ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error + ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error mustEmbedUnimplementedDaemonServiceServer() } -// UnimplementedDaemonServiceServer must be embedded to have forward compatible implementations. -type UnimplementedDaemonServiceServer struct { -} +// UnimplementedDaemonServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedDaemonServiceServer struct{} func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") + return nil, status.Error(codes.Unimplemented, "method Login not implemented") } func (UnimplementedDaemonServiceServer) WaitSSOLogin(context.Context, *WaitSSOLoginRequest) (*WaitSSOLoginResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method WaitSSOLogin not implemented") + return nil, status.Error(codes.Unimplemented, "method WaitSSOLogin not implemented") } func (UnimplementedDaemonServiceServer) Up(context.Context, *UpRequest) (*UpResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Up not implemented") + return nil, status.Error(codes.Unimplemented, "method Up not implemented") } func (UnimplementedDaemonServiceServer) Status(context.Context, *StatusRequest) (*StatusResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Status not implemented") + return nil, status.Error(codes.Unimplemented, "method Status not implemented") } func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*DownResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Down not implemented") + return nil, status.Error(codes.Unimplemented, "method Down not implemented") } func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") + return nil, status.Error(codes.Unimplemented, "method GetConfig not implemented") } func (UnimplementedDaemonServiceServer) ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListNetworks not implemented") + return nil, status.Error(codes.Unimplemented, "method ListNetworks not implemented") } func (UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SelectNetworks not implemented") + return nil, status.Error(codes.Unimplemented, "method SelectNetworks not implemented") } func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented") + return nil, status.Error(codes.Unimplemented, "method DeselectNetworks not implemented") } func (UnimplementedDaemonServiceServer) ForwardingRules(context.Context, *EmptyRequest) (*ForwardingRulesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ForwardingRules not implemented") + return nil, status.Error(codes.Unimplemented, "method ForwardingRules not implemented") } func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented") + return nil, status.Error(codes.Unimplemented, "method DebugBundle not implemented") } func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetLogLevel not implemented") + return nil, status.Error(codes.Unimplemented, "method GetLogLevel not implemented") } func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented") + return nil, status.Error(codes.Unimplemented, "method SetLogLevel not implemented") } func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented") + return nil, status.Error(codes.Unimplemented, "method ListStates not implemented") } func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented") + return nil, status.Error(codes.Unimplemented, "method CleanState not implemented") } func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented") + return nil, status.Error(codes.Unimplemented, "method DeleteState not implemented") } func (UnimplementedDaemonServiceServer) SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SetSyncResponsePersistence not implemented") + return nil, status.Error(codes.Unimplemented, "method SetSyncResponsePersistence not implemented") } func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") + return nil, status.Error(codes.Unimplemented, "method TracePacket not implemented") } -func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error { - return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented") +func (UnimplementedDaemonServiceServer) StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error { + return status.Error(codes.Unimplemented, "method StartCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) { + return nil, status.Error(codes.Unimplemented, "method StartBundleCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) { + return nil, status.Error(codes.Unimplemented, "method StopBundleCapture not implemented") +} +func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error { + return status.Error(codes.Unimplemented, "method SubscribeEvents not implemented") } func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented") + return nil, status.Error(codes.Unimplemented, "method GetEvents not implemented") } func (UnimplementedDaemonServiceServer) SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SwitchProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method SwitchProfile not implemented") } func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SetConfig not implemented") + return nil, status.Error(codes.Unimplemented, "method SetConfig not implemented") } func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method AddProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method AddProfile not implemented") } func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RemoveProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method RemoveProfile not implemented") } func (UnimplementedDaemonServiceServer) ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListProfiles not implemented") + return nil, status.Error(codes.Unimplemented, "method ListProfiles not implemented") } func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetActiveProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method GetActiveProfile not implemented") } func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented") + return nil, status.Error(codes.Unimplemented, "method Logout not implemented") } func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") + return nil, status.Error(codes.Unimplemented, "method GetFeatures not implemented") } func (UnimplementedDaemonServiceServer) TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method TriggerUpdate not implemented") + return nil, status.Error(codes.Unimplemented, "method TriggerUpdate not implemented") } func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") + return nil, status.Error(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") } func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented") + return nil, status.Error(codes.Unimplemented, "method RequestJWTAuth not implemented") } func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") + return nil, status.Error(codes.Unimplemented, "method WaitJWTToken not implemented") } func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method StartCPUProfile not implemented") } func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented") -} -func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") + return nil, status.Error(codes.Unimplemented, "method StopCPUProfile not implemented") } func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") + return nil, status.Error(codes.Unimplemented, "method GetInstallerResult not implemented") } -func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error { - return status.Errorf(codes.Unimplemented, "method ExposeService not implemented") +func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error { + return status.Error(codes.Unimplemented, "method ExposeService not implemented") } func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} +func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to DaemonServiceServer will @@ -665,6 +773,13 @@ type UnsafeDaemonServiceServer interface { } func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) { + // If the following call panics, it indicates UnimplementedDaemonServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } s.RegisterService(&DaemonService_ServiceDesc, srv) } @@ -678,7 +793,7 @@ func _DaemonService_Login_Handler(srv interface{}, ctx context.Context, dec func } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Login", + FullMethod: DaemonService_Login_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest)) @@ -696,7 +811,7 @@ func _DaemonService_WaitSSOLogin_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/WaitSSOLogin", + FullMethod: DaemonService_WaitSSOLogin_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest)) @@ -714,7 +829,7 @@ func _DaemonService_Up_Handler(srv interface{}, ctx context.Context, dec func(in } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Up", + FullMethod: DaemonService_Up_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest)) @@ -732,7 +847,7 @@ func _DaemonService_Status_Handler(srv interface{}, ctx context.Context, dec fun } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Status", + FullMethod: DaemonService_Status_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest)) @@ -750,7 +865,7 @@ func _DaemonService_Down_Handler(srv interface{}, ctx context.Context, dec func( } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Down", + FullMethod: DaemonService_Down_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest)) @@ -768,7 +883,7 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetConfig", + FullMethod: DaemonService_GetConfig_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest)) @@ -786,7 +901,7 @@ func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListNetworks", + FullMethod: DaemonService_ListNetworks_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) @@ -804,7 +919,7 @@ func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SelectNetworks", + FullMethod: DaemonService_SelectNetworks_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -822,7 +937,7 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeselectNetworks", + FullMethod: DaemonService_DeselectNetworks_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -840,7 +955,7 @@ func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ForwardingRules", + FullMethod: DaemonService_ForwardingRules_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest)) @@ -858,7 +973,7 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DebugBundle", + FullMethod: DaemonService_DebugBundle_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest)) @@ -876,7 +991,7 @@ func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetLogLevel", + FullMethod: DaemonService_GetLogLevel_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest)) @@ -894,7 +1009,7 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SetLogLevel", + FullMethod: DaemonService_SetLogLevel_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest)) @@ -912,7 +1027,7 @@ func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListStates", + FullMethod: DaemonService_ListStates_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest)) @@ -930,7 +1045,7 @@ func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/CleanState", + FullMethod: DaemonService_CleanState_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest)) @@ -948,7 +1063,7 @@ func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeleteState", + FullMethod: DaemonService_DeleteState_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest)) @@ -966,7 +1081,7 @@ func _DaemonService_SetSyncResponsePersistence_Handler(srv interface{}, ctx cont } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SetSyncResponsePersistence", + FullMethod: DaemonService_SetSyncResponsePersistence_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetSyncResponsePersistence(ctx, req.(*SetSyncResponsePersistenceRequest)) @@ -984,7 +1099,7 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/TracePacket", + FullMethod: DaemonService_TracePacket_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest)) @@ -992,26 +1107,63 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_StartCapture_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(StartCaptureRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DaemonServiceServer).StartCapture(m, &grpc.GenericServerStream[StartCaptureRequest, CapturePacket]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_StartCaptureServer = grpc.ServerStreamingServer[CapturePacket] + +func _DaemonService_StartBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StartBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: DaemonService_StartBundleCapture_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, req.(*StartBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_StopBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StopBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: DaemonService_StopBundleCapture_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, req.(*StopBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(SubscribeRequest) if err := stream.RecvMsg(m); err != nil { return err } - return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream}) + return srv.(DaemonServiceServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeRequest, SystemEvent]{ServerStream: stream}) } -type DaemonService_SubscribeEventsServer interface { - Send(*SystemEvent) error - grpc.ServerStream -} - -type daemonServiceSubscribeEventsServer struct { - grpc.ServerStream -} - -func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error { - return x.ServerStream.SendMsg(m) -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_SubscribeEventsServer = grpc.ServerStreamingServer[SystemEvent] func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetEventsRequest) @@ -1023,7 +1175,7 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetEvents", + FullMethod: DaemonService_GetEvents_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest)) @@ -1041,7 +1193,7 @@ func _DaemonService_SwitchProfile_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SwitchProfile", + FullMethod: DaemonService_SwitchProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SwitchProfile(ctx, req.(*SwitchProfileRequest)) @@ -1059,7 +1211,7 @@ func _DaemonService_SetConfig_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SetConfig", + FullMethod: DaemonService_SetConfig_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetConfig(ctx, req.(*SetConfigRequest)) @@ -1077,7 +1229,7 @@ func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/AddProfile", + FullMethod: DaemonService_AddProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).AddProfile(ctx, req.(*AddProfileRequest)) @@ -1095,7 +1247,7 @@ func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/RemoveProfile", + FullMethod: DaemonService_RemoveProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).RemoveProfile(ctx, req.(*RemoveProfileRequest)) @@ -1113,7 +1265,7 @@ func _DaemonService_ListProfiles_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListProfiles", + FullMethod: DaemonService_ListProfiles_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListProfiles(ctx, req.(*ListProfilesRequest)) @@ -1131,7 +1283,7 @@ func _DaemonService_GetActiveProfile_Handler(srv interface{}, ctx context.Contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetActiveProfile", + FullMethod: DaemonService_GetActiveProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetActiveProfile(ctx, req.(*GetActiveProfileRequest)) @@ -1149,7 +1301,7 @@ func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec fun } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Logout", + FullMethod: DaemonService_Logout_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Logout(ctx, req.(*LogoutRequest)) @@ -1167,7 +1319,7 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetFeatures", + FullMethod: DaemonService_GetFeatures_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetFeatures(ctx, req.(*GetFeaturesRequest)) @@ -1185,7 +1337,7 @@ func _DaemonService_TriggerUpdate_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/TriggerUpdate", + FullMethod: DaemonService_TriggerUpdate_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).TriggerUpdate(ctx, req.(*TriggerUpdateRequest)) @@ -1203,7 +1355,7 @@ func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Conte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey", + FullMethod: DaemonService_GetPeerSSHHostKey_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest)) @@ -1221,7 +1373,7 @@ func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/RequestJWTAuth", + FullMethod: DaemonService_RequestJWTAuth_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest)) @@ -1239,7 +1391,7 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/WaitJWTToken", + FullMethod: DaemonService_WaitJWTToken_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest)) @@ -1257,7 +1409,7 @@ func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/StartCPUProfile", + FullMethod: DaemonService_StartCPUProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest)) @@ -1275,7 +1427,7 @@ func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/StopCPUProfile", + FullMethod: DaemonService_StopCPUProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest)) @@ -1283,24 +1435,6 @@ func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, return interceptor(ctx, in, info, handler) } -func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(OSLifecycleRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/daemon.DaemonService/NotifyOSLifecycle", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest)) - } - return interceptor(ctx, in, info, handler) -} - func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(InstallerResultRequest) if err := dec(in); err != nil { @@ -1311,7 +1445,7 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetInstallerResult", + FullMethod: DaemonService_GetInstallerResult_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest)) @@ -1324,21 +1458,11 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr if err := stream.RecvMsg(m); err != nil { return err } - return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream}) + return srv.(DaemonServiceServer).ExposeService(m, &grpc.GenericServerStream[ExposeServiceRequest, ExposeServiceEvent]{ServerStream: stream}) } -type DaemonService_ExposeServiceServer interface { - Send(*ExposeServiceEvent) error - grpc.ServerStream -} - -type daemonServiceExposeServiceServer struct { - grpc.ServerStream -} - -func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error { - return x.ServerStream.SendMsg(m) -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent] // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, @@ -1419,6 +1543,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "TracePacket", Handler: _DaemonService_TracePacket_Handler, }, + { + MethodName: "StartBundleCapture", + Handler: _DaemonService_StartBundleCapture_Handler, + }, + { + MethodName: "StopBundleCapture", + Handler: _DaemonService_StopBundleCapture_Handler, + }, { MethodName: "GetEvents", Handler: _DaemonService_GetEvents_Handler, @@ -1479,16 +1611,17 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "StopCPUProfile", Handler: _DaemonService_StopCPUProfile_Handler, }, - { - MethodName: "NotifyOSLifecycle", - Handler: _DaemonService_NotifyOSLifecycle_Handler, - }, { MethodName: "GetInstallerResult", Handler: _DaemonService_GetInstallerResult_Handler, }, }, Streams: []grpc.StreamDesc{ + { + StreamName: "StartCapture", + Handler: _DaemonService_StartCapture_Handler, + ServerStreams: true, + }, { StreamName: "SubscribeEvents", Handler: _DaemonService_SubscribeEvents_Handler, diff --git a/client/server/capture.go b/client/server/capture.go new file mode 100644 index 000000000..308c00338 --- /dev/null +++ b/client/server/capture.go @@ -0,0 +1,365 @@ +package server + +import ( + "context" + "io" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +const maxBundleCaptureDuration = 10 * time.Minute + +// bundleCapture holds the state of an in-progress capture destined for the +// debug bundle. The lifecycle is: +// +// StartBundleCapture → capture running, writing to temp file +// StopBundleCapture → capture stopped, temp file available +// DebugBundle → temp file included in zip, then cleaned up +type bundleCapture struct { + mu sync.Mutex + sess *capture.Session + file *os.File + engine *internal.Engine + cancel context.CancelFunc + stopped bool +} + +// stop halts the capture session and closes the pcap writer. Idempotent. +func (bc *bundleCapture) stop() { + bc.mu.Lock() + defer bc.mu.Unlock() + + if bc.stopped { + return + } + bc.stopped = true + + if bc.cancel != nil { + bc.cancel() + } + if bc.sess != nil { + bc.sess.Stop() + } +} + +// path returns the temp file path, or "" if no file exists. +func (bc *bundleCapture) path() string { + if bc.file == nil { + return "" + } + return bc.file.Name() +} + +// cleanup removes the temp file. +func (bc *bundleCapture) cleanup() { + if bc.file == nil { + return + } + name := bc.file.Name() + if err := bc.file.Close(); err != nil { + log.Debugf("close bundle capture file: %v", err) + } + if err := os.Remove(name); err != nil && !os.IsNotExist(err) { + log.Debugf("remove bundle capture file: %v", err) + } + bc.file = nil +} + +// StartCapture streams a pcap or text packet capture over gRPC. +// Gated by the --enable-capture service flag. +func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error { + if !s.captureEnabled { + return status.Error(codes.PermissionDenied, + "packet capture is disabled; reinstall or reconfigure the service with --enable-capture") + } + + if d := req.GetDuration(); d != nil && d.AsDuration() < 0 { + return status.Error(codes.InvalidArgument, "duration must not be negative") + } + + matcher, err := parseCaptureFilter(req) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) + } + + pr, pw := io.Pipe() + + opts := capture.Options{ + Matcher: matcher, + SnapLen: req.GetSnapLen(), + Verbose: req.GetVerbose(), + ASCII: req.GetAscii(), + } + if req.GetTextOutput() { + opts.TextOutput = pw + } else { + opts.Output = pw + } + + sess, err := capture.NewSession(opts) + if err != nil { + pw.Close() + return status.Errorf(codes.Internal, "create capture session: %v", err) + } + + engine, err := s.claimCapture(sess) + if err != nil { + sess.Stop() + pw.Close() + return err + } + + if err := engine.SetCapture(sess); err != nil { + s.releaseCapture(sess) + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "set capture: %v", err) + } + + // Send an empty initial message to signal that the capture was accepted. + // The client waits for this before printing the banner, so it must arrive + // before any packet data. + if err := stream.Send(&proto.CapturePacket{}); err != nil { + s.clearCaptureIfOwner(sess, engine) + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "send initial message: %v", err) + } + + ctx := stream.Context() + if d := req.GetDuration(); d != nil { + if dur := d.AsDuration(); dur > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, dur) + defer cancel() + } + } + + go func() { + <-ctx.Done() + s.clearCaptureIfOwner(sess, engine) + sess.Stop() + pw.Close() + }() + defer pr.Close() + + log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr()) + defer func() { + stats := sess.Stats() + log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) + }() + + return streamToGRPC(pr, stream) +} + +func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error { + buf := make([]byte, 32*1024) + for { + n, readErr := r.Read(buf) + if n > 0 { + if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil { + log.Debugf("capture stream send: %v", err) + return nil //nolint:nilerr // client disconnected + } + } + if readErr != nil { + return nil //nolint:nilerr // pipe closed, capture stopped normally + } + } +} + +// StartBundleCapture begins capturing packets to a server-side temp file for +// inclusion in the next debug bundle. Not gated by --enable-capture since the +// output stays on the server (same trust level as CPU profiling). +// +// A timeout auto-stops the capture as a safety net if StopBundleCapture is +// never called (e.g. CLI crash). +func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + s.cleanupBundleCapture() + + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + + engine, err := s.getCaptureEngineLocked() + if err != nil { + // Not fatal: kernel mode or not connected. Log and return success + // so the debug bundle still generates without capture data. + log.Warnf("packet capture unavailable, skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + + timeout := req.GetTimeout().AsDuration() + if timeout <= 0 || timeout > maxBundleCaptureDuration { + timeout = maxBundleCaptureDuration + } + + f, err := os.CreateTemp("", "netbird.capture.*.pcap") + if err != nil { + return nil, status.Errorf(codes.Internal, "create temp file: %v", err) + } + + sess, err := capture.NewSession(capture.Options{Output: f}) + if err != nil { + f.Close() + os.Remove(f.Name()) + return nil, status.Errorf(codes.Internal, "create capture session: %v", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + f.Close() + os.Remove(f.Name()) + log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + s.activeCapture = sess + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + bc := &bundleCapture{ + sess: sess, + file: f, + engine: engine, + cancel: cancel, + } + + s.bundleCapture = bc + + go func() { + <-ctx.Done() + s.mutex.Lock() + if s.bundleCapture == bc { + s.stopBundleCaptureLocked() + } else { + bc.stop() + } + s.mutex.Unlock() + log.Infof("bundle capture auto-stopped after timeout") + }() + log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name()) + + return &proto.StartBundleCaptureResponse{}, nil +} + +// StopBundleCapture stops the running bundle capture. Idempotent. +func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + return &proto.StopBundleCaptureResponse{}, nil +} + +// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex. +func (s *Server) stopBundleCaptureLocked() { + if s.bundleCapture == nil { + return + } + bc := s.bundleCapture + if bc.engine != nil && s.activeCapture == bc.sess { + if err := bc.engine.SetCapture(nil); err != nil { + log.Debugf("clear bundle capture: %v", err) + } + s.activeCapture = nil + } + bc.stop() + + stats := bc.sess.Stats() + log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) +} + +// bundleCapturePath returns the temp file path if a capture has been taken, +// stops any running capture, and returns "". Called from DebugBundle. +// Must hold s.mutex. +func (s *Server) bundleCapturePath() string { + if s.bundleCapture == nil { + return "" + } + + s.bundleCapture.stop() + return s.bundleCapture.path() +} + +// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex. +func (s *Server) cleanupBundleCapture() { + if s.bundleCapture == nil { + return + } + s.bundleCapture.cleanup() + s.bundleCapture = nil +} + +// claimCapture reserves the engine's capture slot for sess. Returns +// FailedPrecondition if another capture is already active. +func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + engine, err := s.getCaptureEngineLocked() + if err != nil { + return nil, err + } + s.activeCapture = sess + return engine, nil +} + +// releaseCapture clears the active-capture owner if it still matches sess. +func (s *Server) releaseCapture(sess *capture.Session) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture == sess { + s.activeCapture = nil + } +} + +// clearCaptureIfOwner clears engine's capture slot only if sess still owns it. +func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture != sess { + return + } + if err := engine.SetCapture(nil); err != nil { + log.Debugf("clear capture: %v", err) + } + s.activeCapture = nil +} + +func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) { + if s.connectClient == nil { + return nil, status.Error(codes.FailedPrecondition, "client not connected") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, status.Error(codes.FailedPrecondition, "engine not initialized") + } + return engine, nil +} + +// parseCaptureFilter returns a Matcher from the request. +// Returns nil (match all) when no filter expression is set. +func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) { + expr := req.GetFilterExpr() + if expr == "" { + return nil, nil //nolint:nilnil // nil Matcher means "match all" + } + return capture.ParseFilter(expr) +} diff --git a/client/server/debug.go b/client/server/debug.go index 81708e576..33247db5f 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -43,7 +43,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( }() } - // Prepare refresh callback for health probes + capturePath := s.bundleCapturePath() + defer s.cleanupBundleCapture() + var refreshStatus func() if s.connectClient != nil { engine := s.connectClient.Engine() @@ -62,6 +64,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( SyncResponse: syncResponse, LogPath: s.logFile, CPUProfile: cpuProfileData, + CapturePath: capturePath, RefreshStatus: refreshStatus, ClientMetrics: clientMetrics, }, diff --git a/client/server/network.go b/client/server/network.go index bb1cce56c..76c5af40e 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -9,6 +9,8 @@ import ( "strings" "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/route" @@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } @@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } @@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } diff --git a/client/server/server.go b/client/server/server.go index 7c1e70692..648ffa8ce 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,7 @@ const ( errRestoreResidualState = "failed to restore residual state: %v" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" + errNetworksDisabled = "network selection is disabled by the administrator" ) var ErrServiceNotUp = errors.New("service is not up") @@ -88,6 +90,11 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool + captureEnabled bool + bundleCapture *bundleCapture + // activeCapture is the session currently installed on the engine; guarded by s.mutex. + activeCapture *capture.Session + networksDisabled bool sleepHandler *sleephandler.SleepHandler @@ -104,7 +111,7 @@ type oauthAuthFlow struct { } // New server instance constructor. -func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server { +func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server { s := &Server{ rootCtx: ctx, logFile: logFile, @@ -113,10 +120,13 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + captureEnabled: captureEnabled, + networksDisabled: networksDisabled, jwtCache: newJWTCache(), } agent := &serverAgent{s} s.sleepHandler = sleephandler.New(agent) + s.startSleepDetector() return s } @@ -1359,6 +1369,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized") } + if engine.IsBlockInbound() { + return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first") + } + mgr := engine.GetExposeManager() if mgr == nil { return gstatus.Errorf(codes.Internal, "expose manager not available") @@ -1624,6 +1638,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) features := &proto.GetFeaturesResponse{ DisableProfiles: s.checkProfilesDisabled(), DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + DisableNetworks: s.networksDisabled, } return features, nil diff --git a/client/server/server_test.go b/client/server/server_test.go index 6de23d501..641cd85fe 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -36,6 +36,7 @@ import ( daemonProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug", "", false, false) + s := New(ctx, "debug", "", false, false, false, false) s.config = config @@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) @@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) @@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { return nil, "", err } @@ -329,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 8e360175d..b90b5653d 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.NoError(t, err) ctx := context.Background() - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false, false) rosenpassEnabled := true rosenpassPermissive := true diff --git a/client/server/sleep.go b/client/server/sleep.go index 7a83c75a6..877ad9690 100644 --- a/client/server/sleep.go +++ b/client/server/sleep.go @@ -2,13 +2,18 @@ package server import ( "context" + "os" + "strconv" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/sleep" "github.com/netbirdio/netbird/client/proto" ) +const envDisableSleepDetector = "NB_DISABLE_SLEEP_DETECTOR" + // serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces type serverAgent struct { s *Server @@ -28,19 +33,61 @@ func (a *serverAgent) Status() (internal.StatusType, error) { return internal.CtxGetState(a.s.rootCtx).Status() } -// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type. -func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) { - switch req.GetType() { - case proto.OSLifecycleRequest_WAKEUP: - if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil { - return &proto.OSLifecycleResponse{}, err - } - case proto.OSLifecycleRequest_SLEEP: - if err := s.sleepHandler.HandleSleep(callerCtx); err != nil { - return &proto.OSLifecycleResponse{}, err - } - default: - log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType()) +// startSleepDetector starts the OS sleep/wake detector and forwards events to +// the sleep handler. On platforms without a supported detector the attempt +// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips +// registration entirely. +func (s *Server) startSleepDetector() { + if sleepDetectorDisabled() { + log.Info("sleep detection disabled via " + envDisableSleepDetector) + return } - return &proto.OSLifecycleResponse{}, nil + + svc, err := sleep.New() + if err != nil { + log.Warnf("failed to initialize sleep detection: %v", err) + return + } + + err = svc.Register(func(event sleep.EventType) { + switch event { + case sleep.EventTypeSleep: + log.Info("handling sleep event") + if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil { + log.Errorf("failed to handle sleep event: %v", err) + } + case sleep.EventTypeWakeUp: + log.Info("handling wakeup event") + if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil { + log.Errorf("failed to handle wakeup event: %v", err) + } + } + }) + if err != nil { + log.Errorf("failed to register sleep detector: %v", err) + return + } + + log.Info("sleep detection service initialized") + + go func() { + <-s.rootCtx.Done() + log.Info("stopping sleep event listener") + if err := svc.Deregister(); err != nil { + log.Errorf("failed to deregister sleep detector: %v", err) + } + }() +} + +func sleepDetectorDisabled() bool { + val := os.Getenv(envDisableSleepDetector) + if val == "" { + return false + } + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err) + return false + } + return disabled } diff --git a/client/server/state.go b/client/server/state.go index 8dca6bde1..f2d823465 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -12,7 +12,6 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -138,10 +137,8 @@ func restoreResidualState(ctx context.Context, statePath string) error { } // clean up any remaining routes independently of the state file - if !nbnet.AdvancedRouting() { - if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) - } + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) } return nberrors.FormatErrorOrNil(merr) diff --git a/client/server/state_generic.go b/client/server/state_generic.go index 980ba0cda..86475ca42 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 019477d8e..b193d4dfa 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index cc47fd2d2..6e584b2c3 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -187,24 +187,23 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { return "", fmt.Errorf("get NetBird executable path: %w", err) } - hostLine := strings.Join(deduplicatedPatterns, " ") - config := fmt.Sprintf("Host %s\n", hostLine) - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) - config += " PreferredAuthentications password,publickey,keyboard-interactive\n" - config += " PasswordAuthentication yes\n" - config += " PubkeyAuthentication yes\n" - config += " BatchMode no\n" - config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) - config += " StrictHostKeyChecking no\n" + hostList := strings.Join(deduplicatedPatterns, ",") + config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath) + config += " PreferredAuthentications password,publickey,keyboard-interactive\n" + config += " PasswordAuthentication yes\n" + config += " PubkeyAuthentication yes\n" + config += " BatchMode no\n" + config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) + config += " StrictHostKeyChecking no\n" if runtime.GOOS == "windows" { - config += " UserKnownHostsFile NUL\n" + config += " UserKnownHostsFile NUL\n" } else { - config += " UserKnownHostsFile /dev/null\n" + config += " UserKnownHostsFile /dev/null\n" } - config += " CheckHostIP no\n" - config += " LogLevel ERROR\n\n" + config += " CheckHostIP no\n" + config += " LogLevel ERROR\n\n" return config, nil } diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go index dc3ad95b3..e7380c7f2 100644 --- a/client/ssh/config/manager_test.go +++ b/client/ssh/config/manager_test.go @@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) { assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") } +func TestManager_MatchHostFormat(t *testing.T) { + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + peers := []PeerSSHInfo{ + {Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"}, + {Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"}, + } + + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + + // Must use "Match host" with comma-separated patterns, not a bare "Host" directive. + // A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block + // ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped. + assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive") + assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"", + "should use Match host with comma-separated patterns") +} + func TestManager_ForcedSSHConfig(t *testing.T) { // Set force environment variable t.Setenv(EnvForceSSHConfig, "true") diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index 8897b9c7e..59007f75c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error { func (p *SSHProxy) handleSSHSession(session ssh.Session) { ptyReq, winCh, isPty := session.Pty() - hasCommand := len(session.Command()) > 0 + hasCommand := session.RawCommand() != "" sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User()) if err != nil { @@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) { } if hasCommand { - if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { + if err := serverSession.Run(session.RawCommand()); err != nil { log.Debugf("run command: %v", err) p.handleProxyExitCode(session, err) } diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index dba2e88da..b33d5f8f4 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) { cancel() } +// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting +// when forwarding commands to the backend. This is critical for tools like +// Ansible that send commands such as: +// +// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0' +// +// The single quotes must be preserved so the backend shell receives the +// subshell expression as a single argument to -c. +func TestSSHProxy_CommandQuoting(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + sshClient, cleanup := setupProxySSHClient(t) + defer cleanup() + + // These commands simulate what the SSH protocol delivers as exec payloads. + // When a user types: ssh host '/bin/sh -c "( echo hello )"' + // the local shell strips the outer single quotes, and the SSH exec request + // contains the raw string: /bin/sh -c "( echo hello )" + // + // The proxy must forward this string verbatim. Using session.Command() + // (shlex.Split + strings.Join) strips the inner double quotes, breaking + // the command on the backend. + tests := []struct { + name string + command string + expect string + }{ + { + name: "subshell_in_double_quotes", + command: `/bin/sh -c "( echo from-subshell ) && echo outer"`, + expect: "from-subshell\nouter\n", + }, + { + name: "printf_with_special_chars", + command: `/bin/sh -c "printf '%s\n' 'hello world'"`, + expect: "hello world\n", + }, + { + name: "nested_command_substitution", + command: `/bin/sh -c "echo $(echo nested)"`, + expect: "nested\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + session, err := sshClient.NewSession() + require.NoError(t, err) + defer func() { _ = session.Close() }() + + var stderrBuf bytes.Buffer + session.Stderr = &stderrBuf + + outputCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + output, err := session.Output(tc.command) + outputCh <- output + errCh <- err + }() + + select { + case output := <-outputCh: + err := <-errCh + if stderrBuf.Len() > 0 { + t.Logf("stderr: %s", stderrBuf.String()) + } + require.NoError(t, err, "command should succeed: %s", tc.command) + assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command) + case <-time.After(5 * time.Second): + t.Fatalf("command timed out: %s", tc.command) + } + }) + } +} + +// setupProxySSHClient creates a full proxy test environment and returns +// an SSH client connected through the proxy to a backend NetBird SSH server. +func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) { + t.Helper() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + hostPubKey, err := nbssh.GeneratePublicKey(hostKey) + require.NoError(t, err) + + serverConfig := &server.Config{ + HostKeyPEM: hostKey, + JWT: &server.JWTConfig{ + Issuer: issuer, + Audiences: []string{audience}, + KeysLocation: jwksURL, + }, + } + sshServer := server.New(serverConfig) + sshServer.SetAllowRootLogin(true) + + testUsername := testutil.GetTestUsername(t) + testJWTUser := "test-username" + testUserHash, err := sshuserhash.HashUserID(testJWTUser) + require.NoError(t, err) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + testUsername: {0}, + }, + } + sshServer.UpdateSSHAuth(authConfig) + + sshServerAddr := server.StartTestServer(t, sshServer) + + mockDaemon := startMockDaemon(t) + + host, portStr, err := net.SplitHostPort(sshServerAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + mockDaemon.setHostKey(host, hostPubKey) + + validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser) + mockDaemon.setJWTToken(validToken) + + proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil) + require.NoError(t, err) + + origStdin := os.Stdin + origStdout := os.Stdout + + stdinReader, stdinWriter, err := os.Pipe() + require.NoError(t, err) + stdoutReader, stdoutWriter, err := os.Pipe() + require.NoError(t, err) + + os.Stdin = stdinReader + os.Stdout = stdoutWriter + + clientConn, proxyConn := net.Pipe() + + go func() { _, _ = io.Copy(stdinWriter, proxyConn) }() + go func() { _, _ = io.Copy(proxyConn, stdoutReader) }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + go func() { + _ = proxyInstance.Connect(ctx) + }() + + sshConfig := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig) + require.NoError(t, err) + + client := cryptossh.NewClient(sshClientConn, chans, reqs) + + cleanupFn := func() { + _ = client.Close() + _ = clientConn.Close() + cancel() + os.Stdin = origStdin + os.Stdout = origStdout + _ = sshServer.Stop() + mockDaemon.stop() + jwksServer.Close() + } + + return client, cleanupFn +} + type mockDaemonServer struct { proto.UnimplementedDaemonServiceServer hostKeys map[string][]byte diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 4431ae423..82d3b700f 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) { // Stop closes the SSH server func (s *Server) Stop() error { s.mu.Lock() - defer s.mu.Unlock() - - if s.sshServer == nil { + sshServer := s.sshServer + if sshServer == nil { + s.mu.Unlock() return nil } + s.sshServer = nil + s.listener = nil + s.mu.Unlock() - if err := s.sshServer.Close(); err != nil { + // Close outside the lock: session handlers need s.mu for unregisterSession. + if err := sshServer.Close(); err != nil { log.Debugf("close SSH server: %v", err) } - s.sshServer = nil - s.listener = nil - + s.mu.Lock() maps.Clear(s.sessions) maps.Clear(s.pendingAuthJWT) maps.Clear(s.connections) @@ -307,6 +309,7 @@ func (s *Server) Stop() error { } } maps.Clear(s.remoteForwardListeners) + s.mu.Unlock() return nil } diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index f12a75961..0e531bb96 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) { } ptyReq, winCh, isPty := session.Pty() - hasCommand := len(session.Command()) > 0 + hasCommand := session.RawCommand() != "" if isPty && !hasCommand { // ssh - PTY interactive session (login) diff --git a/client/system/info.go b/client/system/info.go index 01176e765..175d1f07f 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -2,7 +2,6 @@ package system import ( "context" - "net" "net/netip" "strings" @@ -145,56 +144,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string { return v } -func networkAddresses() ([]NetworkAddress, error) { - interfaces, err := net.Interfaces() - if err != nil { - return nil, err - } - - var netAddresses []NetworkAddress - for _, iface := range interfaces { - if iface.HardwareAddr.String() == "" { - continue - } - addrs, err := iface.Addrs() - if err != nil { - continue - } - - for _, address := range addrs { - ipNet, ok := address.(*net.IPNet) - if !ok { - continue - } - - if ipNet.IP.IsLoopback() { - continue - } - - netAddr := NetworkAddress{ - NetIP: netip.MustParsePrefix(ipNet.String()), - Mac: iface.HardwareAddr.String(), - } - - if isDuplicated(netAddresses, netAddr) { - continue - } - - netAddresses = append(netAddresses, netAddr) - } - } - return netAddresses, nil -} - -func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { - for _, duplicated := range addresses { - if duplicated.NetIP == addr.NetIP { - return true - } - } - return false -} - // GetInfoWithChecks retrieves and parses the system information with applied checks. func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { log.Debugf("gathering system information with checks: %d", len(checks)) diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go index 8e1353151..755172842 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info { systemHostname, _ := os.Hostname() + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + return &Info{ - GoOS: runtime.GOOS, - Kernel: osInfo[0], - Platform: runtime.GOARCH, - OS: osName, - OSVersion: osVersion, - Hostname: extractDeviceName(ctx, systemHostname), - CPUs: runtime.NumCPU(), - NetbirdVersion: version.NetbirdVersion(), - UIVersion: extractUserAgent(ctx), - KernelVersion: osInfo[1], - Environment: env, + GoOS: runtime.GOOS, + Kernel: osInfo[0], + Platform: runtime.GOARCH, + OS: osName, + OSVersion: osVersion, + Hostname: extractDeviceName(ctx, systemHostname), + CPUs: runtime.NumCPU(), + NetbirdVersion: version.NetbirdVersion(), + UIVersion: extractUserAgent(ctx), + KernelVersion: osInfo[1], + NetworkAddresses: addrs, + Environment: env, } } diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 322609db4..ad42b1edf 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -2,12 +2,16 @@ package system import ( "context" + "net" + "net/netip" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/version" ) -// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update func UpdateStaticInfoAsync() { // do nothing } @@ -15,11 +19,24 @@ func UpdateStaticInfoAsync() { // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - // Convert fixed-size byte arrays to Go strings sysName := extractOsName(ctx, "sysName") swVersion := extractOsVersion(ctx, "swVersion") - gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion} + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + + gio := &Info{ + Kernel: sysName, + OSVersion: swVersion, + Platform: "unknown", + OS: sysName, + GoOS: runtime.GOOS, + CPUs: runtime.NumCPU(), + KernelVersion: swVersion, + NetworkAddresses: addrs, + } gio.Hostname = extractDeviceName(ctx, "hostname") gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) @@ -27,6 +44,66 @@ func GetInfo(ctx context.Context) *Info { return gio } +// networkAddresses returns the list of network addresses on iOS. +// On iOS, hardware (MAC) addresses are not available due to Apple's privacy +// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we +// leave Mac empty to match Android's behavior. We also skip the HardwareAddr +// check that other platforms use and filter out link-local addresses as they +// are not useful for posture checks. +func networkAddresses() ([]NetworkAddress, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var netAddresses []NetworkAddress + for _, iface := range interfaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, address := range addrs { + netAddr, ok := toNetworkAddress(address) + if !ok { + continue + } + if isDuplicated(netAddresses, netAddr) { + continue + } + netAddresses = append(netAddresses, netAddr) + } + } + return netAddresses, nil +} + +func toNetworkAddress(address net.Addr) (NetworkAddress, bool) { + ipNet, ok := address.(*net.IPNet) + if !ok { + return NetworkAddress{}, false + } + if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() { + return NetworkAddress{}, false + } + prefix, err := netip.ParsePrefix(ipNet.String()) + if err != nil { + return NetworkAddress{}, false + } + return NetworkAddress{NetIP: prefix, Mac: ""}, true +} + +func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { + for _, duplicated := range addresses { + if duplicated.NetIP == addr.NetIP { + return true + } + } + return false +} + // checkFileAndProcess checks if the file path exists and if a process is running at that path. func checkFileAndProcess(paths []string) ([]File, error) { return []File{}, nil diff --git a/client/system/network_addr.go b/client/system/network_addr.go new file mode 100644 index 000000000..5423cf8ad --- /dev/null +++ b/client/system/network_addr.go @@ -0,0 +1,66 @@ +//go:build !ios + +package system + +import ( + "net" + "net/netip" +) + +func networkAddresses() ([]NetworkAddress, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var netAddresses []NetworkAddress + for _, iface := range interfaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + if iface.HardwareAddr.String() == "" { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + + mac := iface.HardwareAddr.String() + for _, address := range addrs { + netAddr, ok := toNetworkAddress(address, mac) + if !ok { + continue + } + if isDuplicated(netAddresses, netAddr) { + continue + } + netAddresses = append(netAddresses, netAddr) + } + } + return netAddresses, nil +} + +func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) { + ipNet, ok := address.(*net.IPNet) + if !ok { + return NetworkAddress{}, false + } + if ipNet.IP.IsLoopback() { + return NetworkAddress{}, false + } + prefix, err := netip.ParsePrefix(ipNet.String()) + if err != nil { + return NetworkAddress{}, false + } + return NetworkAddress{NetIP: prefix, Mac: mac}, true +} + +func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { + for _, duplicated := range addresses { + if duplicated.NetIP == addr.NetIP { + return true + } + } + return false +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 0574e53d0..28f98ae59 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -38,10 +38,10 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/internal/sleep" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" + "github.com/netbirdio/netbird/client/ui/notifier" "github.com/netbirdio/netbird/client/ui/process" "github.com/netbirdio/netbird/util" @@ -260,6 +260,7 @@ type serviceClient struct { // application with main windows. app fyne.App + notifier notifier.Notifier wSettings fyne.Window showAdvancedSettings bool sendNotification bool @@ -314,6 +315,7 @@ type serviceClient struct { lastNotifiedVersion string settingsEnabled bool profilesEnabled bool + networksEnabled bool showNetworks bool wNetworks fyne.Window wProfiles fyne.Window @@ -324,6 +326,7 @@ type serviceClient struct { exitNodeMu sync.Mutex mExitNodeItems []menuHandler exitNodeRetryCancel context.CancelFunc + mExitNodeSeparator *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window @@ -362,11 +365,13 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { cancel: cancel, addr: args.addr, app: args.app, + notifier: notifier.New(args.app), logFile: args.logFile, sendNotification: false, showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, + networksEnabled: true, } s.eventHandler = newEventHandler(s) @@ -889,7 +894,7 @@ func (s *serviceClient) updateStatus() error { if err != nil { log.Errorf("get service status: %v", err) if s.connected { - s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost")) + s.notifier.Send("Error", "Connection to service lost") } s.setDisconnectedStatus() return err @@ -919,8 +924,10 @@ func (s *serviceClient) updateStatus() error { s.mStatus.SetIcon(s.icConnectedDot) s.mUp.Disable() s.mDown.Enable() - s.mNetworks.Enable() - s.mExitNode.Enable() + if s.networksEnabled { + s.mNetworks.Enable() + s.mExitNode.Enable() + } s.startExitNodeRefresh() systrayIconState = true case status.Status == string(internal.StatusConnecting): @@ -1092,19 +1099,19 @@ func (s *serviceClient) onTrayReady() { s.getSrvConfig() time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon for { + // Check features before status so menus respect disable flags before being enabled + s.checkAndUpdateFeatures() + err := s.updateStatus() if err != nil { log.Errorf("error while updating status: %v", err) } - // Check features periodically to handle daemon restarts - s.checkAndUpdateFeatures() - time.Sleep(2 * time.Second) } }() - s.eventManager = event.NewManager(s.app, s.addr) + s.eventManager = event.NewManager(s.notifier, s.addr) s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked()) s.eventManager.AddHandler(func(event *proto.SystemEvent) { if event.Category == proto.SystemEvent_SYSTEM { @@ -1141,9 +1148,6 @@ func (s *serviceClient) onTrayReady() { go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) - - // Start sleep detection listener - go s.startSleepListener() } func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { @@ -1204,62 +1208,6 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService return s.conn, nil } -// startSleepListener initializes the sleep detection service and listens for sleep events -func (s *serviceClient) startSleepListener() { - sleepService, err := sleep.New() - if err != nil { - log.Warnf("%v", err) - return - } - - if err := sleepService.Register(s.handleSleepEvents); err != nil { - log.Errorf("failed to start sleep detection: %v", err) - return - } - - log.Info("sleep detection service initialized") - - // Cleanup on context cancellation - go func() { - <-s.ctx.Done() - log.Info("stopping sleep event listener") - if err := sleepService.Deregister(); err != nil { - log.Errorf("failed to deregister sleep detection: %v", err) - } - }() -} - -// handleSleepEvents sends a sleep notification to the daemon via gRPC -func (s *serviceClient) handleSleepEvents(event sleep.EventType) { - conn, err := s.getSrvClient(0) - if err != nil { - log.Errorf("failed to get daemon client for sleep notification: %v", err) - return - } - - req := &proto.OSLifecycleRequest{} - - switch event { - case sleep.EventTypeWakeUp: - log.Infof("handle wakeup event: %v", event) - req.Type = proto.OSLifecycleRequest_WAKEUP - case sleep.EventTypeSleep: - log.Infof("handle sleep event: %v", event) - req.Type = proto.OSLifecycleRequest_SLEEP - default: - log.Infof("unknown event: %v", event) - return - } - - _, err = conn.NotifyOSLifecycle(s.ctx, req) - if err != nil { - log.Errorf("failed to notify daemon about os lifecycle notification: %v", err) - return - } - - log.Info("successfully notified daemon about os lifecycle") -} - // setSettingsEnabled enables or disables the settings menu based on the provided state func (s *serviceClient) setSettingsEnabled(enabled bool) { if s.mSettings != nil { @@ -1298,6 +1246,16 @@ func (s *serviceClient) checkAndUpdateFeatures() { s.mProfile.setEnabled(profilesEnabled) } } + + // Update networks and exit node menus based on current features + s.networksEnabled = features == nil || !features.DisableNetworks + if s.networksEnabled && s.connected { + s.mNetworks.Enable() + s.mExitNode.Enable() + } else { + s.mNetworks.Disable() + s.mExitNode.Disable() + } } // getFeatures from the daemon to determine which features are enabled/disabled. @@ -1533,7 +1491,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) { if enforced && s.lastNotifiedVersion != newVersion { s.lastNotifiedVersion = newVersion - s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install")) + s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install") } } diff --git a/client/ui/debug.go b/client/ui/debug.go index 29f73a66a..cf5ac1a75 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -16,6 +16,7 @@ import ( "fyne.io/fyne/v2/widget" log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" @@ -24,9 +25,10 @@ import ( // Initial state for the debug collection type debugInitialState struct { - wasDown bool - logLevel proto.LogLevel - isLevelTrace bool + wasDown bool + needsRestoreUp bool + logLevel proto.LogLevel + isLevelTrace bool } // Debug collection parameters @@ -37,6 +39,7 @@ type debugCollectionParams struct { upload bool uploadURL string enablePersistence bool + capture bool } // UI components for progress tracking @@ -50,25 +53,58 @@ type progressUI struct { func (s *serviceClient) showDebugUI() { w := s.app.NewWindow("NetBird Debug") w.SetOnClosed(s.cancel) - w.Resize(fyne.NewSize(600, 500)) w.SetFixedSize(true) anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil) systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil) systemInfoCheck.SetChecked(true) + captureCheck := widget.NewCheck("Include packet capture", nil) uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil) uploadCheck.SetChecked(true) - uploadURLLabel := widget.NewLabel("Debug upload URL:") + uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck) + + debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection() + + statusLabel := widget.NewLabel("") + statusLabel.Hide() + progressBar := widget.NewProgressBar() + progressBar.Hide() + createButton := widget.NewButton("Create Debug Bundle", nil) + + uiControls := []fyne.Disableable{ + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton, + } + + createButton.OnTapped = s.getCreateHandler( + statusLabel, progressBar, uploadCheck, uploadURL, + anonymizeCheck, systemInfoCheck, captureCheck, + runForDurationCheck, durationInput, uiControls, w, + ) + + content := container.NewVBox( + widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), + widget.NewLabel(""), + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURLContainer, + widget.NewLabel(""), + debugModeContainer, noteLabel, + widget.NewLabel(""), + statusLabel, progressBar, createButton, + ) + + w.SetContent(container.NewPadded(content)) + w.Show() +} + +func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) { uploadURL := widget.NewEntry() uploadURL.SetText(uptypes.DefaultBundleURL) uploadURL.SetPlaceHolder("Enter upload URL") - uploadURLContainer := container.NewVBox( - uploadURLLabel, - uploadURL, - ) + uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL) uploadCheck.OnChanged = func(checked bool) { if checked { @@ -77,13 +113,14 @@ func (s *serviceClient) showDebugUI() { uploadURLContainer.Hide() } } + return uploadURLContainer, uploadURL +} - debugModeContainer := container.NewHBox() +func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) { runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil) runForDurationCheck.SetChecked(true) forLabel := widget.NewLabel("for") - durationInput := widget.NewEntry() durationInput.SetText("1") minutesLabel := widget.NewLabel("minute") @@ -107,63 +144,8 @@ func (s *serviceClient) showDebugUI() { } } - debugModeContainer.Add(runForDurationCheck) - debugModeContainer.Add(forLabel) - debugModeContainer.Add(durationInput) - debugModeContainer.Add(minutesLabel) - - statusLabel := widget.NewLabel("") - statusLabel.Hide() - - progressBar := widget.NewProgressBar() - progressBar.Hide() - - createButton := widget.NewButton("Create Debug Bundle", nil) - - // UI controls that should be disabled during debug collection - uiControls := []fyne.Disableable{ - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURL, - runForDurationCheck, - durationInput, - createButton, - } - - createButton.OnTapped = s.getCreateHandler( - statusLabel, - progressBar, - uploadCheck, - uploadURL, - anonymizeCheck, - systemInfoCheck, - runForDurationCheck, - durationInput, - uiControls, - w, - ) - - content := container.NewVBox( - widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), - widget.NewLabel(""), - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURLContainer, - widget.NewLabel(""), - debugModeContainer, - noteLabel, - widget.NewLabel(""), - statusLabel, - progressBar, - createButton, - ) - - paddedContent := container.NewPadded(content) - w.SetContent(paddedContent) - - w.Show() + modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel) + return modeContainer, runForDurationCheck, durationInput, noteLabel } func validateMinute(s string, minutesLabel *widget.Label) error { @@ -199,6 +181,7 @@ func (s *serviceClient) getCreateHandler( uploadURL *widget.Entry, anonymizeCheck *widget.Check, systemInfoCheck *widget.Check, + captureCheck *widget.Check, runForDurationCheck *widget.Check, duration *widget.Entry, uiControls []fyne.Disableable, @@ -221,6 +204,7 @@ func (s *serviceClient) getCreateHandler( params := &debugCollectionParams{ anonymize: anonymizeCheck.Checked, systemInfo: systemInfoCheck.Checked, + capture: captureCheck.Checked, upload: uploadCheck.Checked, uploadURL: url, enablePersistence: true, @@ -252,10 +236,7 @@ func (s *serviceClient) getCreateHandler( statusLabel.SetText("Creating debug bundle...") go s.handleDebugCreation( - anonymizeCheck.Checked, - systemInfoCheck.Checked, - uploadCheck.Checked, - url, + params, statusLabel, uiControls, w, @@ -370,47 +351,72 @@ func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time func (s *serviceClient) configureServiceForDebug( conn proto.DaemonServiceClient, state *debugInitialState, - enablePersistence bool, -) error { + params *debugCollectionParams, +) { if state.wasDown { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - return fmt.Errorf("bring service up: %v", err) + log.Warnf("failed to bring service up: %v", err) + } else { + log.Info("Service brought up for debug") + time.Sleep(time.Second * 10) } - log.Info("Service brought up for debug") - time.Sleep(time.Second * 10) } if !state.isLevelTrace { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil { - return fmt.Errorf("set log level to TRACE: %v", err) + log.Warnf("failed to set log level to TRACE: %v", err) + } else { + log.Info("Log level set to TRACE for debug") } - log.Info("Log level set to TRACE for debug") } if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - return fmt.Errorf("bring service down: %v", err) + log.Warnf("failed to bring service down: %v", err) + } else { + state.needsRestoreUp = !state.wasDown + time.Sleep(time.Second) } - time.Sleep(time.Second) - if enablePersistence { + if params.enablePersistence { if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { - return fmt.Errorf("enable sync response persistence: %v", err) + log.Warnf("failed to enable sync response persistence: %v", err) + } else { + log.Info("Sync response persistence enabled for debug") } - log.Info("Sync response persistence enabled for debug") } if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - return fmt.Errorf("bring service back up: %v", err) + log.Warnf("failed to bring service back up: %v", err) + } else { + state.needsRestoreUp = false + time.Sleep(time.Second * 3) } - time.Sleep(time.Second * 3) if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { log.Warnf("failed to start CPU profiling: %v", err) } - return nil + s.startBundleCaptureIfEnabled(conn, params) +} + +func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) { + if !params.capture { + return + } + + const maxCapture = 10 * time.Minute + timeout := params.duration + 30*time.Second + if timeout > maxCapture { + timeout = maxCapture + log.Warnf("packet capture clamped to %s (server maximum)", maxCapture) + } + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(timeout), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } } func (s *serviceClient) collectDebugData( @@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData( var wg sync.WaitGroup startProgressTracker(ctx, &wg, params.duration, progress) - if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil { - return err - } + s.configureServiceForDebug(conn, state, params) wg.Wait() progress.progressBar.Hide() @@ -436,6 +440,14 @@ func (s *serviceClient) collectDebugData( log.Warnf("failed to stop CPU profiling: %v", err) } + if params.capture { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + } + return nil } @@ -482,9 +494,17 @@ func (s *serviceClient) createDebugBundleFromCollection( // Restore service to original state func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) { + if state.needsRestoreUp { + if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { + log.Warnf("failed to restore up state: %v", err) + } else { + log.Info("Service state restored to up") + } + } + if state.wasDown { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("Failed to restore down state: %v", err) + log.Warnf("failed to restore down state: %v", err) } else { log.Info("Service state restored to down") } @@ -492,7 +512,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat if !state.isLevelTrace { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil { - log.Errorf("Failed to restore log level: %v", err) + log.Warnf("failed to restore log level: %v", err) } else { log.Info("Log level restored to original setting") } @@ -508,18 +528,37 @@ func handleError(progress *progressUI, errMsg string) { } func (s *serviceClient) handleDebugCreation( - anonymize bool, - systemInfo bool, - upload bool, - uploadURL string, + params *debugCollectionParams, statusLabel *widget.Label, uiControls []fyne.Disableable, w fyne.Window, ) { - log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...", - anonymize, systemInfo, upload) + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + log.Errorf("Failed to get client for debug: %v", err) + statusLabel.SetText(fmt.Sprintf("Error: %v", err)) + enableUIControls(uiControls) + return + } - resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL) + if params.capture { + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(30 * time.Second), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } else { + defer func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + }() + time.Sleep(2 * time.Second) + } + } + + resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL) if err != nil { log.Errorf("Failed to create debug bundle: %v", err) statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err)) @@ -531,7 +570,7 @@ func (s *serviceClient) handleDebugCreation( uploadFailureReason := resp.GetUploadFailureReason() uploadedKey := resp.GetUploadedKey() - if upload { + if params.upload { if uploadFailureReason != "" { showUploadFailedDialog(w, localPath, uploadFailureReason) } else { diff --git a/client/ui/event/event.go b/client/ui/event/event.go index b8ed09a5c..ea968f60a 100644 --- a/client/ui/event/event.go +++ b/client/ui/event/event.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "fyne.io/fyne/v2" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -18,11 +17,17 @@ import ( "github.com/netbirdio/netbird/client/ui/desktop" ) +// Notifier sends desktop notifications. Defined here so the event package +// does not depend on fyne or the platform-specific notifier implementation. +type Notifier interface { + Send(title, body string) +} + type Handler func(*proto.SystemEvent) type Manager struct { - app fyne.App - addr string + notifier Notifier + addr string mu sync.Mutex ctx context.Context @@ -31,10 +36,10 @@ type Manager struct { handlers []Handler } -func NewManager(app fyne.App, addr string) *Manager { +func NewManager(notifier Notifier, addr string) *Manager { return &Manager{ - app: app, - addr: addr, + notifier: notifier, + addr: addr, } } @@ -114,7 +119,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) { if id != "" { body += fmt.Sprintf(" ID: %s", id) } - e.app.SendNotification(fyne.NewNotification(title, body)) + e.notifier.Send(title, body) } for _, handler := range handlers { diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 60a580dae..876fcef5f 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -9,7 +9,6 @@ import ( "os" "os/exec" - "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" @@ -87,7 +86,7 @@ func (h *eventHandler) handleConnectClick() { if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { log.Debugf("connect operation cancelled by user") } else { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + h.client.notifier.Send("Error", "Failed to connect") log.Errorf("connect failed: %v", err) } } @@ -112,7 +111,7 @@ func (h *eventHandler) handleDisconnectClick() { if err := h.client.menuDownClick(); err != nil { st, ok := status.FromError(err) if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + h.client.notifier.Send("Error", "Failed to disconnect") log.Errorf("disconnect failed: %v", err) } else { log.Debugf("disconnect cancelled or already disconnecting") @@ -130,7 +129,7 @@ func (h *eventHandler) handleAllowSSHClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings")) + h.client.notifier.Send("Error", "Failed to update SSH settings") } } @@ -140,7 +139,7 @@ func (h *eventHandler) handleAutoConnectClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings")) + h.client.notifier.Send("Error", "Failed to update auto-connect settings") } } @@ -149,7 +148,7 @@ func (h *eventHandler) handleRosenpassClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings")) + h.client.notifier.Send("Error", "Failed to update Rosenpass settings") } } @@ -158,7 +157,7 @@ func (h *eventHandler) handleLazyConnectionClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings")) + h.client.notifier.Send("Error", "Failed to update lazy connection settings") } } @@ -167,7 +166,7 @@ func (h *eventHandler) handleBlockInboundClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings")) + h.client.notifier.Send("Error", "Failed to update block inbound settings") } } @@ -176,7 +175,7 @@ func (h *eventHandler) handleNotificationsClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings")) + h.client.notifier.Send("Error", "Failed to update notifications settings") } else if h.client.eventManager != nil { h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked()) } diff --git a/client/ui/network.go b/client/ui/network.go index ed03f5ada..571e871bb 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -421,6 +421,10 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { node.Remove() } s.mExitNodeItems = nil + if s.mExitNodeSeparator != nil { + s.mExitNodeSeparator.Remove() + s.mExitNodeSeparator = nil + } if s.mExitNodeDeselectAll != nil { s.mExitNodeDeselectAll.Remove() s.mExitNodeDeselectAll = nil @@ -453,31 +457,37 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { } if showDeselectAll { - s.mExitNode.AddSeparator() - deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") - s.mExitNodeDeselectAll = deselectAllItem - go func() { - for { - _, ok := <-deselectAllItem.ClickedCh - if !ok { - // channel closed: exit the goroutine - return - } - exitNodes, err := s.handleExitNodeMenuDeselectAll() - if err != nil { - log.Warnf("failed to handle deselect all exit nodes: %v", err) - } else { - s.exitNodeMu.Lock() - s.recreateExitNodeMenu(exitNodes) - s.exitNodeMu.Unlock() - } - } - - }() + s.addExitNodeDeselectAll() } } +func (s *serviceClient) addExitNodeDeselectAll() { + sep := s.mExitNode.AddSubMenuItem("───────────────", "") + sep.Disable() + s.mExitNodeSeparator = sep + + deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") + s.mExitNodeDeselectAll = deselectAllItem + + go func() { + for { + _, ok := <-deselectAllItem.ClickedCh + if !ok { + return + } + exitNodes, err := s.handleExitNodeMenuDeselectAll() + if err != nil { + log.Warnf("failed to handle deselect all exit nodes: %v", err) + } else { + s.exitNodeMu.Lock() + s.recreateExitNodeMenu(exitNodes) + s.exitNodeMu.Unlock() + } + } + }() +} + func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) { ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout) defer cancel() diff --git a/client/ui/notifier/notifier.go b/client/ui/notifier/notifier.go new file mode 100644 index 000000000..8d1cbe4c4 --- /dev/null +++ b/client/ui/notifier/notifier.go @@ -0,0 +1,27 @@ +// Package notifier sends desktop notifications. On Windows it uses the WinRT +// COM API directly via go-toast/v2 to avoid the PowerShell window flash that +// fyne's default implementation produces. On other platforms it delegates to +// fyne. +package notifier + +import "fyne.io/fyne/v2" + +// Notifier sends desktop notifications. +type Notifier interface { + Send(title, body string) +} + +// New returns a platform-specific Notifier. The fyne app is used as the +// fallback notifier on platforms where no native implementation is wired up, +// and on Windows when the COM path fails to initialize. +func New(app fyne.App) Notifier { + return newNotifier(app) +} + +type fyneNotifier struct { + app fyne.App +} + +func (f *fyneNotifier) Send(title, body string) { + f.app.SendNotification(fyne.NewNotification(title, body)) +} diff --git a/client/ui/notifier/notifier_other.go b/client/ui/notifier/notifier_other.go new file mode 100644 index 000000000..686d2885f --- /dev/null +++ b/client/ui/notifier/notifier_other.go @@ -0,0 +1,9 @@ +//go:build !windows + +package notifier + +import "fyne.io/fyne/v2" + +func newNotifier(app fyne.App) Notifier { + return &fyneNotifier{app: app} +} diff --git a/client/ui/notifier/notifier_windows.go b/client/ui/notifier/notifier_windows.go new file mode 100644 index 000000000..c7afb43ae --- /dev/null +++ b/client/ui/notifier/notifier_windows.go @@ -0,0 +1,88 @@ +package notifier + +import ( + "os" + "path/filepath" + "sync" + + "fyne.io/fyne/v2" + toast "git.sr.ht/~jackmordaunt/go-toast/v2" + "git.sr.ht/~jackmordaunt/go-toast/v2/wintoast" + log "github.com/sirupsen/logrus" +) + +const ( + // appID is the AppUserModelID shown in the Windows Action Center. It + // must match the System.AppUserModel.ID property set on the Start Menu + // shortcut by the MSI (see client/netbird.wxs); otherwise Windows + // groups toasts under a separate, unbranded entry. + appID = "NetBird" + + // appGUID identifies the COM activation callback class. Generated once + // for NetBird; do not change without coordinating an installer bump, + // since old registry entries pointing at the previous GUID would orphan. + appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}" +) + +type comNotifier struct { + fallback *fyneNotifier + ready bool + iconPath string +} + +var ( + initOnce sync.Once + initErr error +) + +func newNotifier(app fyne.App) Notifier { + n := &comNotifier{ + fallback: &fyneNotifier{app: app}, + iconPath: resolveIcon(), + } + initOnce.Do(func() { + initErr = wintoast.SetAppData(wintoast.AppData{ + AppID: appID, + GUID: appGUID, + IconPath: n.iconPath, + }) + }) + if initErr != nil { + log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr) + return n.fallback + } + n.ready = true + return n +} + +func (n *comNotifier) Send(title, body string) { + if !n.ready { + n.fallback.Send(title, body) + return + } + notification := toast.Notification{ + AppID: appID, + Title: title, + Body: body, + Icon: n.iconPath, + } + if err := notification.Push(); err != nil { + log.Warnf("toast: push failed, using fyne fallback: %v", err) + n.fallback.Send(title, body) + } +} + +// resolveIcon returns an absolute path to the toast icon, or an empty string +// when no icon can be located. Windows requires a PNG/JPG for the +// AppUserModelId IconUri registry value; .ico is silently ignored. +func resolveIcon() string { + exe, err := os.Executable() + if err != nil { + return "" + } + candidate := filepath.Join(filepath.Dir(exe), "netbird.png") + if _, err := os.Stat(candidate); err == nil { + return candidate + } + return "" +} diff --git a/client/ui/profile.go b/client/ui/profile.go index 74189c9a0..7ee89e631 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -548,7 +548,7 @@ func (p *profileMenu) refresh() { if err != nil { log.Errorf("failed to switch profile: %v", err) // show notification dialog - p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile")) + p.serviceClient.notifier.Send("Error", "Failed to switch profile") return } @@ -628,9 +628,9 @@ func (p *profileMenu) refresh() { } if err := p.eventHandler.logout(p.ctx); err != nil { log.Errorf("logout failed: %v", err) - p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister")) + p.serviceClient.notifier.Send("Error", "Failed to deregister") } else { - p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully")) + p.serviceClient.notifier.Send("Success", "Deregistered successfully") } } } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d8e50ab6d..cb512f132 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -5,6 +5,7 @@ package main import ( "context" "fmt" + "sync" "syscall/js" "time" @@ -14,6 +15,7 @@ import ( netbird "github.com/netbirdio/netbird/client/embed" sshdetection "github.com/netbirdio/netbird/client/ssh/detection" nbstatus "github.com/netbirdio/netbird/client/status" + wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" @@ -459,6 +461,95 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// createStartCaptureMethod creates the programmable packet capture method. +// Returns a JS interface with onpacket callback and stop() method. +// +// Usage from JavaScript: +// +// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true }) +// cap.onpacket = (line) => console.log(line) +// const stats = cap.stop() +func createStartCaptureMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + iface, err := wasmcapture.Start(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + resolve.Invoke(iface) + }) + }) +} + +// captureMethods returns capture() and stopCapture() that share state for +// the console-log shortcut. capture() logs packets to the browser console +// and stopCapture() ends it, like Ctrl+C on the CLI. +// +// Usage from browser devtools console: +// +// await client.capture() // capture all packets +// await client.capture("tcp") // capture with filter +// await client.capture({filter: "host 10.0.0.1", verbose: true}) +// client.stopCapture() // stop and print stats +func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) { + var mu sync.Mutex + var active *wasmcapture.Handle + + startFn = js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + mu.Lock() + defer mu.Unlock() + + if active != nil { + active.Stop() + active = nil + } + + h, err := wasmcapture.StartConsole(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + active = h + + console := js.Global().Get("console") + console.Call("log", "[capture] started, call client.stopCapture() to stop") + resolve.Invoke(js.Undefined()) + }) + }) + + stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + mu.Lock() + defer mu.Unlock() + + if active == nil { + js.Global().Get("console").Call("log", "[capture] no active capture") + return js.Undefined() + } + + stats := active.Stop() + active = nil + + console := js.Global().Get("console") + console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped)) + return js.Undefined() + }) + + return startFn, stopFn +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { @@ -521,6 +612,11 @@ func createClientObject(client *netbird.Client) js.Value { obj["statusDetail"] = createStatusDetailMethod(client) obj["getSyncResponse"] = createGetSyncResponseMethod(client) obj["setLogLevel"] = createSetLogLevelMethod(client) + obj["startCapture"] = createStartCaptureMethod(client) + + capStart, capStop := captureMethods(client) + obj["capture"] = capStart + obj["stopCapture"] = capStop return js.ValueOf(obj) } diff --git a/client/wasm/internal/capture/capture.go b/client/wasm/internal/capture/capture.go new file mode 100644 index 000000000..53e43c45e --- /dev/null +++ b/client/wasm/internal/capture/capture.go @@ -0,0 +1,176 @@ +//go:build js + +// Package capture bridges the util/capture package to JavaScript via syscall/js. +package capture + +import ( + "strings" + "sync" + "syscall/js" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +// Handle holds a running capture session so it can be stopped later. +type Handle struct { + cs *netbird.CaptureSession + stopFn js.Func + stopped bool +} + +// Stop ends the capture and returns stats. +func (h *Handle) Stop() netbird.CaptureStats { + if h.stopped { + return h.cs.Stats() + } + h.stopped = true + h.stopFn.Release() + + h.cs.Stop() + return h.cs.Stats() +} + +func statsToJS(s netbird.CaptureStats) js.Value { + obj := js.Global().Get("Object").Call("create", js.Null()) + obj.Set("packets", js.ValueOf(s.Packets)) + obj.Set("bytes", js.ValueOf(s.Bytes)) + obj.Set("dropped", js.ValueOf(s.Dropped)) + return obj +} + +// parseOpts extracts filter/verbose/ascii from a JS options value. +func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) { + if jsOpts.IsNull() || jsOpts.IsUndefined() { + return + } + if jsOpts.Type() == js.TypeString { + filter = jsOpts.String() + return + } + if jsOpts.Type() != js.TypeObject { + return + } + if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() { + filter = f.String() + } + if v := jsOpts.Get("verbose"); !v.IsUndefined() { + verbose = v.Truthy() + } + if a := jsOpts.Get("ascii"); !a.IsUndefined() { + ascii = a.Truthy() + } + return +} + +// Start creates a capture session and returns a JS interface for streaming text +// output. The returned object exposes: +// +// onpacket(callback) - set callback(string) for each text line +// stop() - stop capture and return stats { packets, bytes, dropped } +// +// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string. +func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + cb := &jsCallbackWriter{} + + cs, err := client.StartCapture(netbird.CaptureOptions{ + TextOutput: cb, + Filter: filter, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return js.Undefined(), err + } + + handle := &Handle{cs: cs} + + iface := js.Global().Get("Object").Call("create", js.Null()) + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + iface.Set("stop", handle.stopFn) + iface.Set("onpacket", js.Undefined()) + cb.setInterface(iface) + + return iface, nil +} + +// StartConsole starts a capture that logs every packet line to console.log. +// Returns a Handle so the caller can stop it later. +func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + cb := &jsCallbackWriter{} + + cs, err := client.StartCapture(netbird.CaptureOptions{ + TextOutput: cb, + Filter: filter, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return nil, err + } + + handle := &Handle{cs: cs} + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + + iface := js.Global().Get("Object").Call("create", js.Null()) + console := js.Global().Get("console") + iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]"))) + cb.setInterface(iface) + + return handle, nil +} + +// jsCallbackWriter is an io.Writer that buffers text until a newline, then +// invokes the JS onpacket callback with each complete line. +type jsCallbackWriter struct { + mu sync.Mutex + iface js.Value + buf strings.Builder +} + +func (w *jsCallbackWriter) setInterface(iface js.Value) { + w.mu.Lock() + defer w.mu.Unlock() + w.iface = iface +} + +func (w *jsCallbackWriter) Write(p []byte) (int, error) { + w.mu.Lock() + w.buf.Write(p) + + var lines []string + for { + str := w.buf.String() + idx := strings.IndexByte(str, '\n') + if idx < 0 { + break + } + lines = append(lines, str[:idx]) + w.buf.Reset() + if idx+1 < len(str) { + w.buf.WriteString(str[idx+1:]) + } + } + + iface := w.iface + w.mu.Unlock() + + if iface.IsUndefined() { + return len(p), nil + } + cb := iface.Get("onpacket") + if cb.IsUndefined() || cb.IsNull() { + return len(p), nil + } + for _, line := range lines { + cb.Invoke(js.ValueOf(line)) + } + return len(p), nil +} diff --git a/combined/cmd/config.go b/combined/cmd/config.go index 85664d0d2..ce4df8394 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -179,9 +179,11 @@ type StoreConfig struct { // ReverseProxyConfig contains reverse proxy settings type ReverseProxyConfig struct { - TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"` - TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"` - TrustedPeers []string `yaml:"trustedPeers"` + TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"` + TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"` + TrustedPeers []string `yaml:"trustedPeers"` + AccessLogRetentionDays int `yaml:"accessLogRetentionDays"` + AccessLogCleanupIntervalHours int `yaml:"accessLogCleanupIntervalHours"` } // DefaultConfig returns a CombinedConfig with default values @@ -645,7 +647,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { // Build reverse proxy config reverseProxy := nbconfig.ReverseProxy{ - TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount, + TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount, + AccessLogRetentionDays: mgmt.ReverseProxy.AccessLogRetentionDays, + AccessLogCleanupIntervalHours: mgmt.ReverseProxy.AccessLogCleanupIntervalHours, } for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies { if prefix, err := netip.ParsePrefix(p); err == nil { diff --git a/combined/cmd/root.go b/combined/cmd/root.go index ea1ff908a..db986b4d4 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/relay/healthcheck" relayServer "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/ws" sharedMetrics "github.com/netbirdio/netbird/shared/metrics" "github.com/netbirdio/netbird/shared/relay/auth" @@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) - var relayAcceptFn func(conn net.Conn) + var relayAcceptFn func(conn listener.Conn) if relaySrv != nil { relayAcceptFn = relaySrv.RelayAccept() } @@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re } // handleRelayWebSocket handles incoming WebSocket connections for the relay service -func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) { +func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) { acceptOptions := &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, } @@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func( return } - lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress) - if err != nil { - _ = wsConn.Close(websocket.StatusInternalError, "internal error") - return - } - log.Debugf("Relay WS client connected from: %s", rAddr) - conn := ws.NewConn(wsConn, lAddr, rAddr) + conn := ws.NewConn(wsConn, rAddr) acceptFn(conn) } diff --git a/combined/config.yaml.example b/combined/config.yaml.example index dce658d89..af85b0477 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -119,6 +119,8 @@ server: # Reverse proxy settings (optional) # reverseProxy: - # trustedHTTPProxies: [] - # trustedHTTPProxiesCount: 0 - # trustedPeers: [] + # trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"]) + # trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies) + # trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"]) + # accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely). + # accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value. diff --git a/flow/client/client.go b/flow/client/client.go index 318fcfe1e..180a4b441 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -13,12 +13,9 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/status" nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" @@ -26,11 +23,22 @@ import ( "github.com/netbirdio/netbird/util/wsproxy" ) +var ErrClientClosed = errors.New("client is closed") + +// minHealthyDuration is the minimum time a stream must survive before a failure +// resets the backoff timer. Streams that fail faster are considered unhealthy and +// should not reset backoff, so that MaxElapsedTime can eventually stop retries. +const minHealthyDuration = 5 * time.Second + type GRPCClient struct { realClient proto.FlowServiceClient clientConn *grpc.ClientConn stream proto.FlowService_EventsClient - streamMu sync.Mutex + target string + opts []grpc.DialOption + closed bool // prevent creating conn in the middle of the Close + receiving bool // prevent concurrent Receive calls + mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving } func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) { @@ -65,7 +73,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`), ) - conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...) + target := parsedURL.Host + conn, err := grpc.NewClient(target, opts...) if err != nil { return nil, fmt.Errorf("creating new grpc client: %w", err) } @@ -73,30 +82,73 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl return &GRPCClient{ realClient: proto.NewFlowServiceClient(conn), clientConn: conn, + target: target, + opts: opts, }, nil } func (c *GRPCClient) Close() error { - c.streamMu.Lock() - defer c.streamMu.Unlock() - + c.mu.Lock() + c.closed = true c.stream = nil - if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) { + conn := c.clientConn + c.clientConn = nil + c.mu.Unlock() + + if conn == nil { + return nil + } + + if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) { return fmt.Errorf("close client connection: %w", err) } return nil } +func (c *GRPCClient) Send(event *proto.FlowEvent) error { + c.mu.Lock() + stream := c.stream + c.mu.Unlock() + + if stream == nil { + return errors.New("stream not initialized") + } + + if err := stream.Send(event); err != nil { + return fmt.Errorf("send flow event: %w", err) + } + + return nil +} + func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error { + c.mu.Lock() + if c.receiving { + c.mu.Unlock() + return errors.New("concurrent Receive calls are not supported") + } + c.receiving = true + c.mu.Unlock() + defer func() { + c.mu.Lock() + c.receiving = false + c.mu.Unlock() + }() + backOff := defaultBackoff(ctx, interval) operation := func() error { - if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { - return fmt.Errorf("receive: %w: %w", err, context.Canceled) - } + stream, err := c.establishStream(ctx) + if err != nil { + log.Errorf("failed to establish flow stream, retrying: %v", err) + return c.handleRetryableError(err, time.Time{}, backOff) + } + + streamStart := time.Now() + + if err := c.receive(stream, msgHandler); err != nil { log.Errorf("receive failed: %v", err) - return fmt.Errorf("receive: %w", err) + return c.handleRetryableError(err, streamStart, backOff) } return nil } @@ -108,37 +160,106 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan return nil } -func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error { - if c.clientConn.GetState() == connectivity.Shutdown { - return errors.New("connection to flow receiver has been shut down") +// handleRetryableError resets the backoff timer if the stream was healthy long +// enough and recreates the underlying ClientConn so that gRPC's internal +// subchannel backoff does not accumulate and compete with our own retry timer. +// A zero streamStart means the stream was never established. +func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error { + if isContextDone(err) { + return backoff.Permanent(err) } - stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true)) - if err != nil { - return fmt.Errorf("create event stream: %w", err) + var permErr *backoff.PermanentError + if errors.As(err, &permErr) { + return err } - err = stream.Send(&proto.FlowEvent{IsInitiator: true}) + // Reset the backoff so the next retry starts with a short delay instead of + // continuing the already-elapsed timer. Only do this if the stream was healthy + // long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime. + if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration { + backOff.Reset() + } + + if recreateErr := c.recreateConnection(); recreateErr != nil { + log.Errorf("recreate connection: %v", recreateErr) + return recreateErr + } + + log.Infof("connection recreated, retrying stream") + return fmt.Errorf("retrying after error: %w", err) +} + +func (c *GRPCClient) recreateConnection() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return backoff.Permanent(ErrClientClosed) + } + + conn, err := grpc.NewClient(c.target, c.opts...) if err != nil { - log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err) + c.mu.Unlock() + return fmt.Errorf("create new connection: %w", err) + } + + old := c.clientConn + c.clientConn = conn + c.realClient = proto.NewFlowServiceClient(conn) + c.stream = nil + c.mu.Unlock() + + _ = old.Close() + + return nil +} + +func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, backoff.Permanent(ErrClientClosed) + } + cl := c.realClient + c.mu.Unlock() + + // open stream outside the lock — blocking operation + stream, err := cl.Events(ctx) + if err != nil { + return nil, fmt.Errorf("create event stream: %w", err) + } + streamReady := false + defer func() { + if !streamReady { + _ = stream.CloseSend() + } + }() + + if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil { + return nil, fmt.Errorf("send initiator: %w", err) } if err = checkHeader(stream); err != nil { - return fmt.Errorf("check header: %w", err) + return nil, fmt.Errorf("check header: %w", err) } - c.streamMu.Lock() + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, backoff.Permanent(ErrClientClosed) + } c.stream = stream - c.streamMu.Unlock() + c.mu.Unlock() + streamReady = true - return c.receive(stream, msgHandler) + return stream, nil } func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error { for { msg, err := stream.Recv() if err != nil { - return fmt.Errorf("receive from stream: %w", err) + return err } if msg.IsInitiator { @@ -169,7 +290,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error { func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff { return backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 1, + RandomizationFactor: 0.5, Multiplier: 1.7, MaxInterval: interval / 2, MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months @@ -178,18 +299,11 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff }, ctx) } -func (c *GRPCClient) Send(event *proto.FlowEvent) error { - c.streamMu.Lock() - stream := c.stream - c.streamMu.Unlock() - - if stream == nil { - return errors.New("stream not initialized") - } - - if err := stream.Send(event); err != nil { - return fmt.Errorf("send flow event: %w", err) - } - - return nil +// isContextDone reports whether the local context has been canceled or has +// exceeded its deadline. It deliberately does not inspect gRPC status codes: +// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not +// short-circuit our retry loop, since retrying is the correct response when +// the local context is still alive. +func isContextDone(err error) bool { + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) } diff --git a/flow/client/client_test.go b/flow/client/client_test.go index efe01c003..c8f5f4af4 100644 --- a/flow/client/client_test.go +++ b/flow/client/client_test.go @@ -2,8 +2,11 @@ package client_test import ( "context" + "encoding/binary" "errors" "net" + "sync" + "sync/atomic" "testing" "time" @@ -11,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" flow "github.com/netbirdio/netbird/flow/client" "github.com/netbirdio/netbird/flow/proto" @@ -18,21 +23,89 @@ import ( type testServer struct { proto.UnimplementedFlowServiceServer - events chan *proto.FlowEvent - acks chan *proto.FlowEventAck - grpcSrv *grpc.Server - addr string + events chan *proto.FlowEvent + acks chan *proto.FlowEventAck + grpcSrv *grpc.Server + addr string + listener *connTrackListener + closeStream chan struct{} // signal server to close the stream + handlerDone chan struct{} // signaled each time Events() exits + handlerStarted chan struct{} // signaled each time Events() begins +} + +// connTrackListener wraps a net.Listener to track accepted connections +// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM. +type connTrackListener struct { + net.Listener + mu sync.Mutex + conns []net.Conn +} + +func (l *connTrackListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + l.mu.Lock() + l.conns = append(l.conns, c) + l.mu.Unlock() + return c, nil +} + +// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR +// (error code 0x1) on every tracked connection. This produces the exact error: +// +// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR +// +// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload): +// +// Length (3 bytes): 0x000004 +// Type (1 byte): 0x03 (RST_STREAM) +// Flags (1 byte): 0x00 +// Stream ID (4 bytes): target stream (must have bit 31 clear) +// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR) +func (l *connTrackListener) connCount() int { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.conns) +} + +func (l *connTrackListener) sendRSTStream(streamID uint32) { + l.mu.Lock() + defer l.mu.Unlock() + + frame := make([]byte, 13) // 9-byte header + 4-byte payload + // Length = 4 (3 bytes, big-endian) + frame[0], frame[1], frame[2] = 0, 0, 4 + // Type = RST_STREAM (0x03) + frame[3] = 0x03 + // Flags = 0 + frame[4] = 0x00 + // Stream ID (4 bytes, big-endian, bit 31 reserved = 0) + binary.BigEndian.PutUint32(frame[5:9], streamID) + // Error Code = PROTOCOL_ERROR (0x1) + binary.BigEndian.PutUint32(frame[9:13], 0x1) + + for _, c := range l.conns { + _, _ = c.Write(frame) + } } func newTestServer(t *testing.T) *testServer { - listener, err := net.Listen("tcp", "127.0.0.1:0") + rawListener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + listener := &connTrackListener{Listener: rawListener} + s := &testServer{ - events: make(chan *proto.FlowEvent, 100), - acks: make(chan *proto.FlowEventAck, 100), - grpcSrv: grpc.NewServer(), - addr: listener.Addr().String(), + events: make(chan *proto.FlowEvent, 100), + acks: make(chan *proto.FlowEventAck, 100), + grpcSrv: grpc.NewServer(), + addr: rawListener.Addr().String(), + listener: listener, + closeStream: make(chan struct{}, 1), + handlerDone: make(chan struct{}, 10), + handlerStarted: make(chan struct{}, 10), } proto.RegisterFlowServiceServer(s.grpcSrv, s) @@ -51,11 +124,23 @@ func newTestServer(t *testing.T) *testServer { } func (s *testServer) Events(stream proto.FlowService_EventsServer) error { + defer func() { + select { + case s.handlerDone <- struct{}{}: + default: + } + }() + err := stream.Send(&proto.FlowEventAck{IsInitiator: true}) if err != nil { return err } + select { + case s.handlerStarted <- struct{}{}: + default: + } + ctx, cancel := context.WithCancel(stream.Context()) defer cancel() @@ -91,6 +176,8 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error { if err := stream.Send(ack); err != nil { return err } + case <-s.closeStream: + return status.Errorf(codes.Internal, "server closing stream") case <-ctx.Done(): return ctx.Err() } @@ -110,16 +197,13 @@ func TestReceive(t *testing.T) { assert.NoError(t, err, "failed to close flow") }) - receivedAcks := make(map[string]bool) + var ackCount atomic.Int32 receiveDone := make(chan struct{}) go func() { err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { if !msg.IsInitiator && len(msg.EventId) > 0 { - id := string(msg.EventId) - receivedAcks[id] = true - - if len(receivedAcks) >= 3 { + if ackCount.Add(1) >= 3 { close(receiveDone) } } @@ -130,7 +214,11 @@ func TestReceive(t *testing.T) { } }() - time.Sleep(500 * time.Millisecond) + select { + case <-server.handlerStarted: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for stream to be established") + } for i := 0; i < 3; i++ { eventID := uuid.New().String() @@ -153,7 +241,7 @@ func TestReceive(t *testing.T) { t.Fatal("timeout waiting for acks to be processed") } - assert.Equal(t, 3, len(receivedAcks)) + assert.Equal(t, int32(3), ackCount.Load()) } func TestReceive_ContextCancellation(t *testing.T) { @@ -254,3 +342,208 @@ func TestSend(t *testing.T) { t.Fatal("timeout waiting for ack to be received by flow") } } + +func TestNewClient_PermanentClose(t *testing.T) { + server := newTestServer(t) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + + err = client.Close() + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + done := make(chan error, 1) + go func() { + done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + return nil + }) + }() + + select { + case err := <-done: + require.ErrorIs(t, err, flow.ErrClientClosed) + case <-time.After(2 * time.Second): + t.Fatal("Receive did not return after Close — stuck in retry loop") + } +} + +func TestNewClient_CloseVerify(t *testing.T) { + server := newTestServer(t) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + done := make(chan error, 1) + go func() { + done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + return nil + }) + }() + + closeDone := make(chan struct{}, 1) + go func() { + _ = client.Close() + closeDone <- struct{}{} + }() + + select { + case err := <-done: + require.Error(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Receive did not return after Close — stuck in retry loop") + } + + select { + case <-closeDone: + return + case <-time.After(2 * time.Second): + t.Fatal("Close did not return — blocked in retry loop") + } + +} + +func TestClose_WhileReceiving(t *testing.T) { + server := newTestServer(t) + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + + ctx := context.Background() // no timeout — intentional + receiveDone := make(chan struct{}) + go func() { + _ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + return nil + }) + close(receiveDone) + }() + + // Wait for the server-side handler to confirm the stream is established. + select { + case <-server.handlerStarted: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for stream to be established") + } + + closeDone := make(chan struct{}) + go func() { + _ = client.Close() + close(closeDone) + }() + + select { + case <-closeDone: + // Close returned — good + case <-time.After(2 * time.Second): + t.Fatal("Close blocked forever — Receive stuck in retry loop") + } + + select { + case <-receiveDone: + case <-time.After(2 * time.Second): + t.Fatal("Receive did not exit after Close") + } +} + +func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { + server := newTestServer(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + + // Cleanups run LIFO: the goroutine-drain registered here runs after Close below, + // which is when Receive has actually returned. Without this, the Receive goroutine + // can outlive the test and call t.Logf after teardown, panicking. + receiveDone := make(chan struct{}) + t.Cleanup(func() { + select { + case <-receiveDone: + case <-time.After(2 * time.Second): + t.Error("Receive goroutine did not exit after Close") + } + }) + t.Cleanup(func() { + err := client.Close() + assert.NoError(t, err, "failed to close flow") + }) + + // Track acks received before and after server-side stream close + var ackCount atomic.Int32 + receivedFirst := make(chan struct{}) + receivedAfterReconnect := make(chan struct{}) + + go func() { + defer close(receiveDone) + err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + if msg.IsInitiator || len(msg.EventId) == 0 { + return nil + } + n := ackCount.Add(1) + if n == 1 { + close(receivedFirst) + } + if n == 2 { + close(receivedAfterReconnect) + } + return nil + }) + if err != nil && !errors.Is(err, context.Canceled) { + t.Logf("receive error: %v", err) + } + }() + + // Wait for stream to be established, then send first ack + select { + case <-server.handlerStarted: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for stream to be established") + } + server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")} + + select { + case <-receivedFirst: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for first ack") + } + + // Snapshot connection count before injecting the fault. + connsBefore := server.listener.connCount() + + // Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection. + // gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated). + // Stream ID 1 is the client's first stream (our Events bidi stream). + // This produces the exact error the client sees in production: + // "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR" + server.listener.sendRSTStream(1) + + // Wait for the old Events() handler to fully exit so it can no longer + // drain s.acks and drop our injected ack on a broken stream. + select { + case <-server.handlerDone: + case <-time.After(5 * time.Second): + t.Fatal("old Events() handler did not exit after RST_STREAM") + } + + require.Eventually(t, func() bool { + return server.listener.connCount() > connsBefore + }, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM") + + server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")} + + select { + case <-receivedAfterReconnect: + // Client successfully reconnected and received ack after server-side stream close + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect") + } + + assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close") + assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)") +} diff --git a/go.mod b/go.mod index 89bc06fea..8e6a481d2 100644 --- a/go.mod +++ b/go.mod @@ -13,28 +13,29 @@ require ( github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.6 github.com/rs/cors v1.8.0 - github.com/sirupsen/logrus v1.9.3 + github.com/sirupsen/logrus v1.9.4 github.com/spf13/cobra v1.10.1 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.48.0 - golang.org/x/sys v0.41.0 + golang.org/x/crypto v0.49.0 + golang.org/x/sys v0.42.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.79.3 + google.golang.org/grpc v1.80.0 google.golang.org/protobuf v1.36.11 - gopkg.in/natefinch/lumberjack.v2 v2.0.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) require ( fyne.io/fyne/v2 v2.7.0 fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 + git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 github.com/awnumar/memguard v0.23.0 - github.com/aws/aws-sdk-go-v2 v1.36.3 - github.com/aws/aws-sdk-go-v2/config v1.29.14 - github.com/aws/aws-sdk-go-v2/credentials v1.17.67 - github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 + github.com/aws/aws-sdk-go-v2 v1.38.3 + github.com/aws/aws-sdk-go-v2/config v1.31.6 + github.com/aws/aws-sdk-go-v2/credentials v1.18.10 + github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 @@ -42,13 +43,17 @@ require ( github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-oidc/v3 v3.14.1 github.com/creack/pty v1.1.24 + github.com/crowdsecurity/crowdsec v1.7.7 + github.com/crowdsecurity/go-cs-bouncer v0.0.21 github.com/dexidp/dex v0.0.0-00010101000000-000000000000 github.com/dexidp/dex/api/v2 v2.4.0 + github.com/ebitengine/purego v0.8.4 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2 github.com/fsnotify/fsnotify v1.9.0 github.com/gliderlabs/ssh v0.3.8 + github.com/go-jose/go-jose/v4 v4.1.3 github.com/godbus/dbus/v5 v5.1.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang/mock v1.6.0 @@ -59,15 +64,16 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 - github.com/hashicorp/go-version v1.6.0 + github.com/hashicorp/go-version v1.7.0 github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 + github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-netroute v0.2.1 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 @@ -102,22 +108,22 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 - go.opentelemetry.io/otel v1.42.0 + go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/prometheus v0.64.0 - go.opentelemetry.io/otel/metric v1.42.0 - go.opentelemetry.io/otel/sdk/metric v1.42.0 + go.opentelemetry.io/otel/metric v1.43.0 + go.opentelemetry.io/otel/sdk/metric v1.43.0 go.uber.org/mock v0.5.2 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 - golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 + golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/mobile v0.0.0-20251113184115-a159579294ab - golang.org/x/mod v0.32.0 - golang.org/x/net v0.51.0 - golang.org/x/oauth2 v0.34.0 - golang.org/x/sync v0.19.0 - golang.org/x/term v0.40.0 - golang.org/x/time v0.14.0 - google.golang.org/api v0.257.0 + golang.org/x/mod v0.33.0 + golang.org/x/net v0.52.0 + golang.org/x/oauth2 v0.36.0 + golang.org/x/sync v0.20.0 + golang.org/x/term v0.41.0 + golang.org/x/time v0.15.0 + google.golang.org/api v0.276.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.7 @@ -127,11 +133,11 @@ require ( ) require ( - cloud.google.com/go/auth v0.17.0 // indirect + cloud.google.com/go/auth v0.20.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect dario.cat/mergo v1.0.1 // indirect - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.1.1 // indirect github.com/AppsFlyer/go-sundheit v0.6.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect @@ -142,36 +148,38 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/awnumar/memcall v0.4.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6 // indirect github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect - github.com/aws/smithy-go v1.22.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.2 // indirect + github.com/aws/smithy-go v1.23.0 // indirect github.com/beevik/etree v1.6.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/crowdsecurity/go-cs-lib v0.0.25 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.0.1+incompatible // indirect - github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/ebitengine/purego v0.8.2 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.1 // indirect github.com/fyne-io/gl-js v0.2.0 // indirect @@ -181,43 +189,59 @@ require ( github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect - github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-ldap/ldap/v3 v3.4.12 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.2 // indirect + github.com/go-openapi/jsonpointer v0.21.1 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.1 // indirect + github.com/go-openapi/validate v0.24.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.1 // indirect + github.com/goccy/go-yaml v1.18.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/google/btree v1.1.2 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect - github.com/googleapis/gax-go/v2 v2.15.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect + github.com/googleapis/gax-go/v2 v2.21.0 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/hack-pad/safejs v0.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/huandu/xstrings v1.5.0 // indirect + github.com/huin/goupnp v1.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jonboulle/clockwork v0.5.0 // indirect + github.com/josharian/intern v1.0.0 // indirect github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect + github.com/koron/go-ssdp v0.0.4 // indirect github.com/kr/fs v0.1.0 // indirect github.com/lib/pq v1.10.9 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.10 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect @@ -225,6 +249,7 @@ require ( github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -236,7 +261,8 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect - github.com/nxadm/tail v1.4.8 // indirect + github.com/nxadm/tail v1.4.11 // indirect + github.com/oklog/ulid v1.3.1 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect @@ -246,41 +272,43 @@ require ( github.com/pion/transport/v2 v2.2.4 // indirect github.com/pion/turn/v4 v4.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/otlptranslator v1.0.0 // indirect github.com/prometheus/procfs v0.19.2 // indirect - github.com/russellhaering/goxmldsig v1.5.0 // indirect + github.com/russellhaering/goxmldsig v1.6.0 // indirect github.com/rymdport/portal v0.4.2 // indirect - github.com/shirou/gopsutil/v4 v4.25.1 // indirect + github.com/shirou/gopsutil/v4 v4.25.8 // indirect github.com/shoenig/go-m1cpu v0.2.1 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/tklauser/go-sysconf v0.3.14 // indirect - github.com/tklauser/numcpus v0.8.0 // indirect + github.com/tklauser/go-sysconf v0.3.15 // indirect + github.com/tklauser/numcpus v0.10.0 // indirect github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect github.com/yuin/goldmark v1.7.8 // indirect github.com/zeebo/blake3 v0.2.3 // indirect + go.mongodb.org/mongo-driver v1.17.9 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect - go.opentelemetry.io/otel/sdk v1.42.0 // indirect - go.opentelemetry.io/otel/trace v1.42.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect + go.opentelemetry.io/otel/sdk v1.43.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/image v0.33.0 // indirect - golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect + golang.org/x/text v0.35.0 // indirect + golang.org/x/tools v0.42.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 @@ -296,3 +324,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 + +replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index 629388ccb..2abf55142 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4= -cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ= +cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA= +cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= @@ -9,12 +9,14 @@ cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= +filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ= fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE= fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4= fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= +git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 h1:N3IGoHHp9pb6mj1cbXbuaSXV/UMKwmbKLf53nQmtqMA= +git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3/go.mod h1:QtOLZGz8olr4qH2vWK0QH0w0O4T9fEIjMuWpKUsH7nc= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk= @@ -40,48 +42,50 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w= github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A= github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M= -github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= -github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10/go.mod h1:qqvMj6gHLR/EXWZw4ZbqlPbQUyenf4h82UQUlKc+l14= -github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= -github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2 v1.38.3 h1:B6cV4oxnMs45fql4yRH+/Po/YU+597zgWqvDpYMturk= +github.com/aws/aws-sdk-go-v2 v1.38.3/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 h1:i8p8P4diljCr60PpJp6qZXNlgX4m2yQFpYk+9ZT+J4E= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1/go.mod h1:ddqbooRZYNoJ2dsTwOty16rM+/Aqmk/GOXrK8cg7V00= +github.com/aws/aws-sdk-go-v2/config v1.31.6 h1:a1t8fXY4GT4xjyJExz4knbuoxSCacB5hT/WgtfPyLjo= +github.com/aws/aws-sdk-go-v2/config v1.31.6/go.mod h1:5ByscNi7R+ztvOGzeUaIu49vkMk2soq5NaH5PYe33MQ= +github.com/aws/aws-sdk-go-v2/credentials v1.18.10 h1:xdJnXCouCx8Y0NncgoptztUocIYLKeQxrCgN6x9sdhg= +github.com/aws/aws-sdk-go-v2/credentials v1.18.10/go.mod h1:7tQk08ntj914F/5i9jC4+2HQTAuJirq7m1vZVIhEkWs= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 h1:wbjnrrMnKew78/juW7I2BtKQwa1qlf6EjQgS69uYY14= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6/go.mod h1:AtiqqNrDioJXuUgz3+3T0mBWN7Hro2n9wll2zRUc0ww= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 h1:uF68eJA6+S9iVr9WgX1NaRGyQ/6MdIyc4JNUo6TN1FA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6/go.mod h1:qlPeVZCGPiobx8wb1ft0GHT5l+dc6ldnwInDFaMvC7Y= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 h1:pa1DEC6JoI0zduhZePp3zmhWvk/xxm4NB8Hy/Tlsgos= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6/go.mod h1:gxEjPebnhWGJoaDdtDkA0JX46VRg1wcTHYe63OfX5pE= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 h1:ZNTqv4nIdE/DiBfUUfXcLZ/Spcuz+RjeziUtNJackkM= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34/go.mod h1:zf7Vcd1ViW7cPqYWEHLHJkS50X0JS2IKz9Cgaj6ugrs= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 h1:lguz0bmOoGzozP9XfRJR1QIayEYo+2vP/No3OfLF0pU= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0/go.mod h1:iu6FSzgt+M2/x3Dk8zhycdIcHjEFb36IS8HVUVFoMg0= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 h1:moLQUoVq91LiqT1nbvzDukyqAlCv89ZmwaHw/ZFlFZg= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15/go.mod h1:ZH34PJUc8ApjBIfgQCFvkWcUDBtl/WTD+uiYHjd8igA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6 h1:R0tNFJqfjHL3900cqhXuwQ+1K4G0xc9Yf8EDbFXCKEw= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6/go.mod h1:y/7sDdu+aJvPtGXr4xYosdpq9a6T9Z0jkXfugmti0rI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6 h1:hncKj/4gR+TPauZgTAsxOxNcvBayhUlYZ6LO/BYiQ30= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6/go.mod h1:OiIh45tp6HdJDDJGnja0mw8ihQGz3VGrUflLqSL0SmM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6 h1:LHS1YAIJXJ4K9zS+1d/xa9JAA9sL2QyXIQCQFQW/X08= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6/go.mod h1:c9PCiTEuh0wQID5/KqA32J+HAgZxN9tOGXKCiYJjTZI= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6 h1:nEXUSAwyUfLTgnc9cxlDWy637qsq4UWwp3sNAfl0Z3Y= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6/go.mod h1:HGzIULx4Ge3Do2V0FaiYKcyKzOqwrhUZgCI77NisswQ= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA= -github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 h1:tWUG+4wZqdMl/znThEk9tcCy8tTMxq8dW0JTgamohrY= -github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2/go.mod h1:U5SNqwhXB3Xe6F47kXvWihPl/ilGaEDe8HD/50Z9wxc= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= -github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= -github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 h1:ETkfWcXP2KNPLecaDa++5bsQhCRa5M5sLUJa5DWYIIg= +github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3/go.mod h1:+/3ZTqoYb3Ur7DObD00tarKMLMuKg8iqz5CHEanqTnw= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.1 h1:8OLZnVJPvjnrxEwHFg9hVUof/P4sibH+Ea4KKuqAGSg= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.1/go.mod h1:27M3BpVi0C02UiQh1w9nsBEit6pLhlaH3NHna6WUbDE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2 h1:gKWSTnqudpo8dAxqBqZnDoDWCiEh/40FziUjr/mo6uA= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2/go.mod h1:x7+rkNmRoEN1U13A6JE2fXne9EWyJy54o3n6d4mGaXQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.2 h1:YZPjhyaGzhDQEvsffDEcpycq49nl7fiGcfJTIo8BszI= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.2/go.mod h1:2dIN8qhQfv37BdUYGgEC8Q3tteM3zFxTI1MLO2O3J3c= +github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= +github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE= github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -99,6 +103,8 @@ github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+Y github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= @@ -118,11 +124,18 @@ github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHf github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= +github.com/crowdsecurity/crowdsec v1.7.7 h1:sduZN763iXsrZodocWDrsR//7nLeffGu+RVkkIsbQkE= +github.com/crowdsecurity/crowdsec v1.7.7/go.mod h1:L1HLGPDnBYCcY+yfSFnuBbQ1G9DHEJN9c+Kevv9F+4Q= +github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy18CzYvYab8UJDU= +github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA= +github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4= +github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE= github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0= github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A= github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= @@ -131,12 +144,12 @@ github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/docker/docker v28.0.1+incompatible h1:FCHjSRdXhNRFjlHMTv4jUNlIBbTeRjrWfeFuJp7jpo0= github.com/docker/docker v28.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= -github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I= -github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= +github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/eko/gocache/lib/v4 v4.2.0 h1:MNykyi5Xw+5Wu3+PUrvtOCaKSZM1nUSVftbzmeC7Yuw= github.com/eko/gocache/lib/v4 v4.2.0/go.mod h1:7ViVmbU+CzDHzRpmB4SXKyyzyuJ8A3UW3/cszpcqB4M= github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEMl+NivxtR5R3/hw= @@ -155,6 +168,7 @@ github.com/fredbi/uri v1.1.1 h1:xZHJC08GZNIUhbP5ImTHnt5Ya0T8FI2VAwI/37kh2Ko= github.com/fredbi/uri v1.1.1/go.mod h1:4+DZQ5zBjEwQCDmXW5JdIjz0PUA+yJbvtBv+u+adr5o= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs= @@ -187,6 +201,24 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.22.2 h1:rdxhzcBUazEcGccKqbY1Y7NS8FDcMyIRr0934jrYnZg= +github.com/go-openapi/errors v0.22.2/go.mod h1:+n/5UdIqdVnLIJ6Q9Se8HNGUXYaY6CN8ImWzfi/Gzp0= +github.com/go-openapi/jsonpointer v0.21.1 h1:whnzv/pNXtK2FbX/W9yJfRmE2gsmkfahjMKB0fZvcic= +github.com/go-openapi/jsonpointer v0.21.1/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU= +github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= @@ -203,10 +235,14 @@ github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= +github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -230,6 +266,7 @@ github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl76 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -237,6 +274,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= @@ -248,10 +287,10 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= -github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= -github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= -github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= +github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= +github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= +github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI= +github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= @@ -276,11 +315,13 @@ github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PU github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= -github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= +github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY= +github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -291,6 +332,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= +github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= @@ -315,6 +358,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= @@ -326,8 +371,10 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0= +github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -346,6 +393,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= +github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= +github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= @@ -376,6 +425,8 @@ github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa1 github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -398,12 +449,14 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= +github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= +github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -415,10 +468,13 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXqaDbq7ju94viiQ= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY= +github.com/nxadm/tail v1.4.11/go.mod h1:OTaG3NK980DZzxbRq6lEuzgU+mug70nY11sMd4JXXHc= github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI= github.com/oapi-codegen/runtime v1.1.2/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc= github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -439,8 +495,8 @@ github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq5 github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= -github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= @@ -478,8 +534,9 @@ github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw= github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= @@ -503,15 +560,15 @@ github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/russellhaering/goxmldsig v1.5.0 h1:AU2UkkYIUOTyZRbe08XMThaOCelArgvNfYapcmSjBNw= -github.com/russellhaering/goxmldsig v1.5.0/go.mod h1:x98CjQNFJcWfMxeOrMnMKg70lvDP6tE0nTaeUnjXDmk= +github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks= +github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU= github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8= -github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs= -github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI= +github.com/shirou/gopsutil/v4 v4.25.8 h1:NnAsw9lN7587WHxjJA9ryDnqhJpFH6A+wagYWTOH970= +github.com/shirou/gopsutil/v4 v4.25.8/go.mod h1:q9QdMmfAOVIw7a+eF86P7ISEU6ka+NLgkUxlopV4RwI= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/go-m1cpu v0.2.1 h1:yqRB4fvOge2+FyRXFkXqsyMoqPazv14Yyy+iyccT2E4= github.com/shoenig/go-m1cpu v0.2.1/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w= @@ -520,8 +577,8 @@ github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk= github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= @@ -570,11 +627,11 @@ github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40= github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= -github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= -github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= +github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= +github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= -github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY= -github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= +github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfjso= +github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= @@ -603,28 +660,30 @@ github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= +go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU= +go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= -go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= -go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= go.opentelemetry.io/otel/exporters/prometheus v0.64.0 h1:g0LRDXMX/G1SEZtK8zl8Chm4K6GBwRkjPKE36LxiTYs= go.opentelemetry.io/otel/exporters/prometheus v0.64.0/go.mod h1:UrgcjnarfdlBDP3GjDIJWe6HTprwSazNjwsI+Ru6hro= -go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= -go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= -go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= -go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= -go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= -go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= -go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= -go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -650,10 +709,10 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -668,8 +727,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= -golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -688,11 +747,11 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= -golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= -golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -704,8 +763,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -725,8 +784,8 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -740,8 +799,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -754,8 +813,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -767,10 +826,10 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= -golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= -golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -782,8 +841,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -794,19 +853,19 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA= -google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY= +google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= -google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= -google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= +google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI= +google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -828,8 +887,8 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8 gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= -gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= diff --git a/idp/dex/config.go b/idp/dex/config.go index 3db04a4cb..e686233ad 100644 --- a/idp/dex/config.go +++ b/idp/dex/config.go @@ -170,20 +170,68 @@ type Connector struct { } // ToStorageConnector converts a Connector to storage.Connector type. +// It maps custom connector types (e.g., "zitadel", "entra") to Dex-native types +// and augments the config with OIDC defaults when needed. func (c *Connector) ToStorageConnector() (storage.Connector, error) { - data, err := json.Marshal(c.Config) + dexType, augmentedConfig := mapConnectorToDex(c.Type, c.Config) + + data, err := json.Marshal(augmentedConfig) if err != nil { return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err) } return storage.Connector{ ID: c.ID, - Type: c.Type, + Type: dexType, Name: c.Name, Config: data, }, nil } +// mapConnectorToDex maps custom connector types to Dex-native types and applies +// OIDC defaults. This ensures static connectors from config files or env vars +// are stored with types that Dex can open. +func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) { + switch connType { + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": + return "oidc", applyOIDCDefaults(connType, config) + default: + return connType, config + } +} + +// applyOIDCDefaults clones the config map, sets common OIDC defaults, +// and applies provider-specific overrides. +func applyOIDCDefaults(connType string, config map[string]interface{}) map[string]interface{} { + augmented := make(map[string]interface{}, len(config)+4) + for k, v := range config { + augmented[k] = v + } + setDefault(augmented, "scopes", []string{"openid", "profile", "email"}) + setDefault(augmented, "insecureEnableGroups", true) + setDefault(augmented, "insecureSkipEmailVerified", true) + + switch connType { + case "zitadel": + setDefault(augmented, "getUserInfo", true) + case "entra": + setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"}) + case "okta", "pocketid": + augmented["scopes"] = []string{"openid", "profile", "email", "groups"} + case "adfs": + augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"} + } + + return augmented +} + +// setDefault sets a key in the map only if it doesn't already exist. +func setDefault(m map[string]interface{}, key string, value interface{}) { + if _, ok := m[key]; !ok { + m[key] = value + } +} + // StorageConfig is a configuration that can create a storage. type StorageConfig interface { Open(logger *slog.Logger) (storage.Storage, error) diff --git a/idp/dex/connector.go b/idp/dex/connector.go index ba2bb1f00..8aba92999 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto var err error switch cfg.Type { - case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": dexType = "oidc" configData, err = buildOIDCConnectorConfig(cfg, redirectURI) case "google": @@ -220,6 +220,8 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} case "pocketid": oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} + case "adfs": + oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"} } return encodeConnectorConfig(oidcConfig) } @@ -283,7 +285,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa // inferOIDCProviderType infers the specific OIDC provider from connector ID func inferOIDCProviderType(connectorID string) string { connectorIDLower := strings.ToLower(connectorID) - for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} { + for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} { if strings.Contains(connectorIDLower, provider) { return provider } diff --git a/idp/dex/provider.go b/idp/dex/provider.go index 68fe48486..24aed1b99 100644 --- a/idp/dex/provider.go +++ b/idp/dex/provider.go @@ -4,6 +4,7 @@ package dex import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" "log/slog" @@ -19,10 +20,13 @@ import ( "github.com/dexidp/dex/server" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/sql" + jose "github.com/go-jose/go-jose/v4" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "golang.org/x/crypto/bcrypt" "google.golang.org/grpc" + + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) // Config matches what management/internals/server/server.go expects @@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string { } return issuer + "/auth" } + +// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks. +// This avoids HTTP round-trips when the embedded IDP is co-located with the management server. +// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic: +// SigningKeyPub first, then all VerificationKeys, serialized via go-jose. +func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) { + keys, err := p.storage.GetKeys(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get keys from storage: %w", err) + } + + if keys.SigningKeyPub == nil { + return nil, fmt.Errorf("no public keys found in storage") + } + + // Build the key set exactly as Dex's localSigner.ValidationKeys does: + // signing key first, then all verification (rotated) keys. + joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1) + joseKeys = append(joseKeys, *keys.SigningKeyPub) + for _, vk := range keys.VerificationKeys { + if vk.PublicKey != nil { + joseKeys = append(joseKeys, *vk.PublicKey) + } + } + + // Serialize through go-jose (same as Dex's handlePublicKeys handler) + // then deserialize into our Jwks type, so the JSON field mapping is identical + // to what the /keys HTTP endpoint would return. + joseSet := jose.JSONWebKeySet{Keys: joseKeys} + data, err := json.Marshal(joseSet) + if err != nil { + return nil, fmt.Errorf("failed to marshal JWKS: %w", err) + } + + jwks := &nbjwt.Jwks{} + if err := json.Unmarshal(data, jwks); err != nil { + return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err) + } + + jwks.ExpiresInTime = keys.NextRotation + + return jwks, nil +} diff --git a/idp/dex/provider_test.go b/idp/dex/provider_test.go index bd2f676fb..4ed89fd2e 100644 --- a/idp/dex/provider_test.go +++ b/idp/dex/provider_test.go @@ -2,11 +2,14 @@ package dex import ( "context" + "encoding/json" "log/slog" "os" "path/filepath" "testing" + "github.com/dexidp/dex/storage" + sqllib "github.com/dexidp/dex/storage/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -197,6 +200,295 @@ enablePasswordDB: true t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID) } +// openTestStorage creates a SQLite storage in the given directory for testing. +func openTestStorage(t *testing.T, tmpDir string) storage.Storage { + t.Helper() + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + stor, err := (&sqllib.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger) + require.NoError(t, err) + return stor +} + +func TestStaticConnectors_CreatedFromYAML(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + yamlContent := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + filepath.Join(tmpDir, "dex.db") + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: My OIDC Provider + config: + issuer: https://accounts.example.com + clientID: test-client-id + clientSecret: test-client-secret + redirectURI: http://localhost:5556/dex/callback +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent), 0644) + require.NoError(t, err) + + yamlConfig, err := LoadConfig(configPath) + require.NoError(t, err) + + // Open storage and run initializeStorage directly (avoids Dex server + // trying to dial the OIDC issuer) + stor := openTestStorage(t, tmpDir) + defer stor.Close() + + err = initializeStorage(ctx, stor, yamlConfig) + require.NoError(t, err) + + // Verify connector was created in storage + conn, err := stor.GetConnector(ctx, "my-oidc") + require.NoError(t, err) + assert.Equal(t, "my-oidc", conn.ID) + assert.Equal(t, "My OIDC Provider", conn.Name) + assert.Equal(t, "oidc", conn.Type) + + // Verify config fields were serialized correctly + var configMap map[string]interface{} + err = json.Unmarshal(conn.Config, &configMap) + require.NoError(t, err) + assert.Equal(t, "https://accounts.example.com", configMap["issuer"]) + assert.Equal(t, "test-client-id", configMap["clientID"]) +} + +func TestStaticConnectors_UpdatedOnRestart(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-update-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First: load config with initial connector + yamlContent1 := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + dbFile + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: Original Name + config: + issuer: https://accounts.example.com + clientID: original-client-id + clientSecret: original-secret +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent1), 0644) + require.NoError(t, err) + + yamlConfig1, err := LoadConfig(configPath) + require.NoError(t, err) + + stor := openTestStorage(t, tmpDir) + err = initializeStorage(ctx, stor, yamlConfig1) + require.NoError(t, err) + + // Verify initial state + conn, err := stor.GetConnector(ctx, "my-oidc") + require.NoError(t, err) + assert.Equal(t, "Original Name", conn.Name) + + var configMap1 map[string]interface{} + err = json.Unmarshal(conn.Config, &configMap1) + require.NoError(t, err) + assert.Equal(t, "original-client-id", configMap1["clientID"]) + + // Close storage to simulate restart + stor.Close() + + // Second: load updated config against the same DB + yamlContent2 := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + dbFile + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: Updated Name + config: + issuer: https://accounts.example.com + clientID: updated-client-id + clientSecret: updated-secret +` + err = os.WriteFile(configPath, []byte(yamlContent2), 0644) + require.NoError(t, err) + + yamlConfig2, err := LoadConfig(configPath) + require.NoError(t, err) + + stor2 := openTestStorage(t, tmpDir) + defer stor2.Close() + + err = initializeStorage(ctx, stor2, yamlConfig2) + require.NoError(t, err) + + // Verify connector was updated, not duplicated + allConnectors, err := stor2.ListConnectors(ctx) + require.NoError(t, err) + + nonLocalCount := 0 + for _, c := range allConnectors { + if c.ID != "local" { + nonLocalCount++ + } + } + assert.Equal(t, 1, nonLocalCount, "connector should be updated, not duplicated") + + conn2, err := stor2.GetConnector(ctx, "my-oidc") + require.NoError(t, err) + assert.Equal(t, "Updated Name", conn2.Name) + + var configMap2 map[string]interface{} + err = json.Unmarshal(conn2.Config, &configMap2) + require.NoError(t, err) + assert.Equal(t, "updated-client-id", configMap2["clientID"]) +} + +func TestStaticConnectors_MultipleConnectors(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-multi-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + yamlContent := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + filepath.Join(tmpDir, "dex.db") + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: My OIDC Provider + config: + issuer: https://accounts.example.com + clientID: oidc-client-id + clientSecret: oidc-secret +- type: google + id: my-google + name: Google Login + config: + clientID: google-client-id + clientSecret: google-secret +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent), 0644) + require.NoError(t, err) + + yamlConfig, err := LoadConfig(configPath) + require.NoError(t, err) + + stor := openTestStorage(t, tmpDir) + defer stor.Close() + + err = initializeStorage(ctx, stor, yamlConfig) + require.NoError(t, err) + + allConnectors, err := stor.ListConnectors(ctx) + require.NoError(t, err) + + // Build a map for easier assertion + connByID := make(map[string]storage.Connector) + for _, c := range allConnectors { + connByID[c.ID] = c + } + + // Verify both static connectors exist + oidcConn, ok := connByID["my-oidc"] + require.True(t, ok, "oidc connector should exist") + assert.Equal(t, "My OIDC Provider", oidcConn.Name) + assert.Equal(t, "oidc", oidcConn.Type) + + var oidcConfig map[string]interface{} + err = json.Unmarshal(oidcConn.Config, &oidcConfig) + require.NoError(t, err) + assert.Equal(t, "oidc-client-id", oidcConfig["clientID"]) + + googleConn, ok := connByID["my-google"] + require.True(t, ok, "google connector should exist") + assert.Equal(t, "Google Login", googleConn.Name) + assert.Equal(t, "google", googleConn.Type) + + var googleConfig map[string]interface{} + err = json.Unmarshal(googleConn.Config, &googleConfig) + require.NoError(t, err) + assert.Equal(t, "google-client-id", googleConfig["clientID"]) + + // Verify local connector still exists alongside them (enablePasswordDB: true) + localConn, ok := connByID["local"] + require.True(t, ok, "local connector should exist") + assert.Equal(t, "local", localConn.Type) +} + +func TestStaticConnectors_EmptyList(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-empty-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + yamlContent := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + filepath.Join(tmpDir, "dex.db") + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent), 0644) + require.NoError(t, err) + + yamlConfig, err := LoadConfig(configPath) + require.NoError(t, err) + + provider, err := NewProviderFromYAML(ctx, yamlConfig) + require.NoError(t, err) + defer func() { _ = provider.Stop(ctx) }() + + // No static connectors configured, so ListConnectors should return empty + connectors, err := provider.ListConnectors(ctx) + require.NoError(t, err) + assert.Empty(t, connectors) + + // But local connector should still exist + localConn, err := provider.Storage().GetConnector(ctx, "local") + require.NoError(t, err) + assert.Equal(t, "local", localConn.ID) +} + func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) { ctx := context.Background() diff --git a/infrastructure_files/getting-started-with-dex.sh b/infrastructure_files/getting-started-with-dex.sh index a14c6134e..5e605f19c 100755 --- a/infrastructure_files/getting-started-with-dex.sh +++ b/infrastructure_files/getting-started-with-dex.sh @@ -172,8 +172,11 @@ init_environment() { echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" echo "" echo "Login with the following credentials:" - echo "Email: admin@$NETBIRD_DOMAIN" | tee .env - echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env + install -m 600 /dev/null .env + printf 'Email: admin@%s\nPassword: %s\n' \ + "$NETBIRD_DOMAIN" "$NETBIRD_ADMIN_PASSWORD" >> .env + echo "Email: admin@$NETBIRD_DOMAIN" + echo "Password: $NETBIRD_ADMIN_PASSWORD" echo "" echo "Dex admin UI is not available (Dex has no built-in UI)." echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex" diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 09c5225ad..f503cbeac 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -563,8 +563,11 @@ initEnvironment() { echo -e "\nDone!\n" echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" echo "Login with the following credentials:" - echo "Username: $ZITADEL_ADMIN_USERNAME" | tee .env - echo "Password: $ZITADEL_ADMIN_PASSWORD" | tee -a .env + install -m 600 /dev/null .env + printf 'Username: %s\nPassword: %s\n' \ + "$ZITADEL_ADMIN_USERNAME" "$ZITADEL_ADMIN_PASSWORD" >> .env + echo "Username: $ZITADEL_ADMIN_USERNAME" + echo "Password: $ZITADEL_ADMIN_PASSWORD" } renderCaddyfile() { diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 70088d66a..9d1b57258 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -182,6 +182,23 @@ read_enable_proxy() { return 0 } +read_enable_crowdsec() { + echo "" > /dev/stderr + echo "Do you want to enable CrowdSec IP reputation blocking?" > /dev/stderr + echo "CrowdSec checks client IPs against a community threat intelligence database" > /dev/stderr + echo "and blocks known malicious sources before they reach your services." > /dev/stderr + echo "A local CrowdSec LAPI container will be added to your deployment." > /dev/stderr + echo -n "Enable CrowdSec? [y/N]: " > /dev/stderr + read -r CHOICE < /dev/tty + + if [[ "$CHOICE" =~ ^[Yy]$ ]]; then + echo "true" + else + echo "false" + fi + return 0 +} + read_traefik_acme_email() { echo "" > /dev/stderr echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr @@ -214,7 +231,20 @@ get_upstream_host() { wait_management_proxy() { local proxy_container="${1:-traefik}" + local use_docker_logs=false set +e + + if [[ "$proxy_container" == "detect-traefik" ]]; then + proxy_container=$(docker ps --format "{{.ID}}\t{{.Image}}\t{{.Ports}}" \ + | awk -F'\t' '$2 ~ /traefik/ && $3 ~ /:(80|443)->/ {print $1; exit}') + + if [[ -z "$proxy_container" ]]; then + echo "Warning: could not auto-detect Traefik container, log output will be skipped on timeout." > /dev/stderr + else + use_docker_logs=true + fi + fi + echo -n "Waiting for NetBird server to become ready" counter=1 while true; do @@ -225,7 +255,13 @@ wait_management_proxy() { if [[ $counter -eq 60 ]]; then echo "" echo "Taking too long. Checking logs..." - $DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container" + if [[ -n "$proxy_container" ]]; then + if [[ "$use_docker_logs" == "true" ]]; then + docker logs --tail=20 "$proxy_container" + else + $DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container" + fi + fi $DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server fi echo -n " ." @@ -297,6 +333,10 @@ initialize_default_values() { # NetBird Proxy configuration ENABLE_PROXY="false" PROXY_TOKEN="" + + # CrowdSec configuration + ENABLE_CROWDSEC="false" + CROWDSEC_BOUNCER_KEY="" return 0 } @@ -325,6 +365,9 @@ configure_reverse_proxy() { if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) ENABLE_PROXY=$(read_enable_proxy) + if [[ "$ENABLE_PROXY" == "true" ]]; then + ENABLE_CROWDSEC=$(read_enable_crowdsec) + fi fi # Handle external Traefik-specific prompts (option 1) @@ -354,7 +397,7 @@ check_existing_installation() { echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." echo "You can use the following commands:" echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" - echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" + echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt && rm -rf crowdsec/" echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." exit 1 fi @@ -375,6 +418,9 @@ generate_configuration_files() { echo "NB_PROXY_TOKEN=placeholder" >> proxy.env # TCP ServersTransport for PROXY protocol v2 to the proxy backend render_traefik_dynamic > traefik-dynamic.yaml + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + mkdir -p crowdsec + fi fi ;; 1) @@ -417,8 +463,12 @@ start_services_and_show_instructions() { if [[ "$ENABLE_PROXY" == "true" ]]; then # Phase 1: Start core services (without proxy) + local core_services="traefik dashboard netbird-server" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + core_services="$core_services crowdsec" + fi echo "Starting core services..." - $DOCKER_COMPOSE_COMMAND up -d traefik dashboard netbird-server + $DOCKER_COMPOSE_COMMAND up -d $core_services sleep 3 wait_management_proxy traefik @@ -438,7 +488,33 @@ start_services_and_show_instructions() { echo "Proxy token created successfully." - # Generate proxy.env with the token + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + echo "Registering CrowdSec bouncer..." + local cs_retries=0 + while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do + cs_retries=$((cs_retries + 1)) + if [[ $cs_retries -ge 30 ]]; then + echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr + echo "You can register a bouncer manually later with:" > /dev/stderr + echo " docker exec netbird-crowdsec cscli bouncers add netbird-proxy -o raw" > /dev/stderr + ENABLE_CROWDSEC="false" + break + fi + sleep 2 + done + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + CROWDSEC_BOUNCER_KEY=$($DOCKER_COMPOSE_COMMAND exec -T crowdsec \ + cscli bouncers add netbird-proxy -o raw 2>/dev/null) + if [[ -z "$CROWDSEC_BOUNCER_KEY" ]]; then + echo "WARNING: Failed to create CrowdSec bouncer key. Skipping CrowdSec setup." > /dev/stderr + ENABLE_CROWDSEC="false" + else + echo "CrowdSec bouncer registered." + fi + fi + fi + render_proxy_env > proxy.env # Start proxy service @@ -461,7 +537,7 @@ start_services_and_show_instructions() { $DOCKER_COMPOSE_COMMAND up -d sleep 3 - wait_management_direct + wait_management_proxy detect-traefik echo -e "$MSG_DONE" print_post_setup_instructions @@ -525,11 +601,25 @@ render_docker_compose_traefik_builtin() { # Generate proxy service section and Traefik dynamic config if enabled local proxy_service="" local proxy_volumes="" + local crowdsec_service="" + local crowdsec_volumes="" local traefik_file_provider="" local traefik_dynamic_volume="" if [[ "$ENABLE_PROXY" == "true" ]]; then traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"' traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro" + + local proxy_depends=" + netbird-server: + condition: service_started" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + proxy_depends=" + netbird-server: + condition: service_started + crowdsec: + condition: service_healthy" + fi + proxy_service=" # NetBird Proxy - exposes internal resources to the internet proxy: @@ -539,8 +629,7 @@ render_docker_compose_traefik_builtin() { - 51820:51820/udp restart: unless-stopped networks: [netbird] - depends_on: - - netbird-server + depends_on:${proxy_depends} env_file: - ./proxy.env volumes: @@ -563,6 +652,35 @@ render_docker_compose_traefik_builtin() { " proxy_volumes=" netbird_proxy_certs:" + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + crowdsec_service=" + crowdsec: + image: crowdsecurity/crowdsec:v1.7.7 + container_name: netbird-crowdsec + restart: unless-stopped + networks: [netbird] + environment: + COLLECTIONS: crowdsecurity/linux + volumes: + - ./crowdsec:/etc/crowdsec + - crowdsec_db:/var/lib/crowdsec/data + healthcheck: + test: ["CMD", "cscli", "lapi", "status"] + interval: 10s + timeout: 5s + retries: 15 + labels: + - traefik.enable=false + logging: + driver: \"json-file\" + options: + max-size: \"500m\" + max-file: \"2\" +" + crowdsec_volumes=" + crowdsec_db:" + fi fi cat <" + echo " Get your enrollment key at: https://app.crowdsec.net" + echo "" + fi fi return 0 } diff --git a/infrastructure_files/observability/grafana/dashboards/management.json b/infrastructure_files/observability/grafana/dashboards/management.json index 95983603f..f116a8bde 100644 --- a/infrastructure_files/observability/grafana/dashboards/management.json +++ b/infrastructure_files/observability/grafana/dashboards/management.json @@ -302,7 +302,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_account_peer_meta_update_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_account_peer_meta_update_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -410,7 +410,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.5,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.5,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -426,7 +426,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.9,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.9,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -443,7 +443,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.99,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.99,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -545,7 +545,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.5,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.5,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -561,7 +561,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.9,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.9,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -578,7 +578,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.99,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.99,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -694,7 +694,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.5,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.5,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -710,7 +710,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.9,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.9,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -727,7 +727,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.99,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.99,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -841,7 +841,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -853,7 +853,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -866,7 +866,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -963,7 +963,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -975,7 +975,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -988,7 +988,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1085,7 +1085,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -1097,7 +1097,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -1110,7 +1110,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1221,7 +1221,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_idp_authenticate_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_idp_authenticate_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -1317,7 +1317,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_idp_get_account_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_idp_get_account_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -1413,7 +1413,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_idp_update_user_meta_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_idp_update_user_meta_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -1523,7 +1523,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"GET|OPTIONS\"}[$__rate_interval])) by (job,method)", + "expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"GET|OPTIONS\"}[$__rate_interval])) by (job,method)", "instant": false, "legendFormat": "{{method}}", "range": true, @@ -1619,7 +1619,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"POST|PUT|DELETE\"}[$__rate_interval])) by (job,method)", + "expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"POST|PUT|DELETE\"}[$__rate_interval])) by (job,method)", "instant": false, "legendFormat": "{{method}}", "range": true, @@ -1715,7 +1715,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -1727,7 +1727,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -1740,7 +1740,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1837,7 +1837,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -1849,7 +1849,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -1862,7 +1862,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1963,7 +1963,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (job,exported_endpoint,method)", + "expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (job,exported_endpoint,method)", "hide": false, "instant": false, "legendFormat": "{{method}}-{{exported_endpoint}}", @@ -3222,7 +3222,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "sum by(le) (increase(management_grpc_updatechannel_queue_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", + "expr": "sum by(le) (increase(management_grpc_updatechannel_queue_length_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -3323,7 +3323,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "sum by(le) (increase(management_account_update_account_peers_duration_ms_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", + "expr": "sum by(le) (increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 4b414df6f..36de950e9 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -7,7 +7,6 @@ import ( "os" "slices" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -16,11 +15,9 @@ import ( "golang.org/x/exp/maps" "golang.org/x/mod/semver" - nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" - "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" @@ -58,13 +55,6 @@ type Controller struct { proxyController port_forwarding.Controller integratedPeerValidator integrated_validator.IntegratedValidator - - holder *types.Holder - - expNewNetworkMap bool - expNewNetworkMapAIDs map[string]struct{} - - compactedNetworkMap bool } type bufferUpdate struct { @@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App log.Fatal(fmt.Errorf("error creating metrics: %w", err)) } - newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err) - newNetworkMapBuilder = false - } - - compactedNetworkMap := true - compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted) - parsedCompactedNmap, err := strconv.ParseBool(compactedEnv) - if err != nil && len(compactedEnv) > 0 { - log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err) - } - if err == nil && !parsedCompactedNmap { - log.WithContext(ctx).Info("disabling compacted mode") - compactedNetworkMap = false - } - - ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") - expIDs := make(map[string]struct{}, len(ids)) - for _, id := range ids { - expIDs[id] = struct{}{} - } - return &Controller{ repo: newRepository(store), metrics: nMetrics, @@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App proxyController: proxyController, EphemeralPeersManager: ephemeralPeersManager, - - holder: types.NewHolder(), - expNewNetworkMap: newNetworkMapBuilder, - expNewNetworkMapAIDs: expIDs, - - compactedNetworkMap: compactedNetworkMap, } } @@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - var ( - account *types.Account - err error - ) - if c.experimentalNetworkMap(accountID) { - account = c.getAccountFromHolderOrInit(ctx, accountID) - } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("failed to get account: %v", err) - } + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) } globalStart := time.Now() @@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin routers := account.GetResourceRoutersMap() groupIDToUserIDs := account.GetActiveGroupUsers() - if c.experimentalNetworkMap(accountID) { - c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) - } - proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) @@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin c.metrics.CountCalcPostureChecksDuration(time.Since(start)) start = time.Now() - var remotePeerNetworkMap *types.NetworkMap - - switch { - case c.experimentalNetworkMap(accountID): - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - case c.compactedNetworkMap: - remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - default: - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } + remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) @@ -317,11 +257,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID // UpdatePeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { - if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return fmt.Errorf("recalculate network map cache: %v", err) +func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) } - return c.sendUpdateAccountPeers(ctx, accountID) } @@ -371,16 +310,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return err } - var remotePeerNetworkMap *types.NetworkMap - - switch { - case c.experimentalNetworkMap(accountId): - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - case c.compactedNetworkMap: - remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - default: - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } + remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -404,9 +334,13 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return nil } -func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { +func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) + } + bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) b := bufUpd.(*bufferUpdate) @@ -421,14 +355,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str go func() { defer b.mu.Unlock() - _ = c.UpdateAccountPeers(ctx, accountID) + _ = c.sendUpdateAccountPeers(ctx, accountID) if !b.update.Load() { return } b.update.Store(false) if b.next == nil { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { - _ = c.UpdateAccountPeers(ctx, accountID) + _ = c.sendUpdateAccountPeers(ctx, accountID) }) return } @@ -451,17 +385,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return peer, emptyMap, nil, 0, nil } - var ( - account *types.Account - err error - ) - if c.experimentalNetworkMap(accountID) { - account = c.getAccountFromHolderOrInit(ctx, accountID) - } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, 0, err - } + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err } account.InjectProxyPolicies(ctx) @@ -493,20 +419,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return nil, nil, nil, 0, err } - var networkMap *types.NetworkMap - - if c.experimentalNetworkMap(accountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - } else { - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - if c.compactedNetworkMap { - networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } - } + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -518,108 +434,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return peer, networkMap, postureChecks, dnsFwdPort, nil } -func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { - c.enrichAccountFromHolder(account) - account.InitNetworkMapBuilderIfNeeded(validatedPeers) -} - -func (c *Controller) getPeerNetworkMapExp( - ctx context.Context, - accountId string, - peerId string, - validatedPeers map[string]struct{}, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - metrics *telemetry.AccountManagerMetrics, -) *types.NetworkMap { - account := c.getAccountFromHolderOrInit(ctx, accountId) - if account == nil { - log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) - return &types.NetworkMap{ - Network: &types.Network{}, - } - } - - return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics) -} - -func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) { - c.enrichAccountFromHolder(account) - account.OnPeersAddedUpdNetworkMapCache(peerIds...) -} - -func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { - c.enrichAccountFromHolder(account) - return account.OnPeerDeletedUpdNetworkMapCache(peerId) -} - -func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { - account := c.getAccountFromHolder(accountId) - if account == nil { - return - } - account.UpdatePeerInNetworkMapCache(peer) -} - -func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { - account.RecalculateNetworkMapCache(validatedPeers) - c.updateAccountInHolder(account) -} - -func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { - if c.experimentalNetworkMap(accountId) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - return err - } - validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) - return err - } - c.recalculateNetworkMapCache(account, validatedPeers) - } - return nil -} - -func (c *Controller) experimentalNetworkMap(accountId string) bool { - _, ok := c.expNewNetworkMapAIDs[accountId] - return c.expNewNetworkMap || ok -} - -func (c *Controller) enrichAccountFromHolder(account *types.Account) { - a := c.holder.GetAccount(account.Id) - if a == nil { - c.holder.AddAccount(account) - return - } - account.NetworkMapCache = a.NetworkMapCache - if account.NetworkMapCache == nil { - return - } - c.holder.AddAccount(account) -} - -func (c *Controller) getAccountFromHolder(accountID string) *types.Account { - return c.holder.GetAccount(accountID) -} - -func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account { - a := c.holder.GetAccount(accountID) - if a != nil { - return a - } - account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure) - if err != nil { - return nil - } - return account -} - -func (c *Controller) updateAccountInHolder(account *types.Account) { - c.holder.AddAccount(account) -} - // GetDNSDomain returns the configured dnsDomain func (c *Controller) GetDNSDomain(settings *types.Settings) string { if settings == nil { @@ -756,16 +570,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t } func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { - peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) - if err != nil { - return fmt.Errorf("failed to get peers by ids: %w", err) - } - - for _, peer := range peers { - c.UpdatePeerInNetworkMapCache(accountID, peer) - } - - err = c.bufferSendUpdateAccountPeers(ctx, accountID) + err := c.bufferSendUpdateAccountPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) } @@ -775,14 +580,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs) - c.onPeersAddedUpdNetworkMapCache(account, peerIDs...) - } return c.bufferSendUpdateAccountPeers(ctx, accountID) } @@ -817,19 +614,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI MessageType: network_map.MessageTypeNetworkMap, }) c.peersUpdateManager.CloseChannel(ctx, peerID) - - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - continue - } - err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) - if err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err) - continue - } - } } return c.bufferSendUpdateAccountPeers(ctx, accountID) @@ -872,21 +656,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N return nil, err } - var networkMap *types.NetworkMap - - if c.experimentalNetworkMap(peer.AccountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil) - } else { - account.InjectProxyPolicies(ctx) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - if c.compactedNetworkMap { - networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) - } - } + account.InjectProxyPolicies(ctx) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index 64caac861..44d8f7d72 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -12,18 +12,15 @@ import ( ) const ( - EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" - EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" - DnsForwarderPort = nbdns.ForwarderServerPort OldForwarderPort = nbdns.ForwarderClientPort DnsForwarderPortMinVersion = "v0.59.0" ) type Controller interface { - UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error - BufferUpdateAccountPeers(ctx context.Context, accountID string) error + BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetDNSDomain(settings *types.Settings) string StartWarmup(context.Context) diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index 4e86d2973..073a75d3b 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -44,17 +44,17 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder { } // BufferUpdateAccountPeers mocks base method. -func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { +func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason) ret0, _ := ret[0].(error) return ret0 } // BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. -func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call { +func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason) } // CountStreams mocks base method. @@ -238,15 +238,15 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId a } // UpdateAccountPeers mocks base method. -func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error { +func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason) ret0, _ := ret[0].(error) return ret0 } // UpdateAccountPeers indicates an expected call of UpdateAccountPeers. -func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call { +func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason) } diff --git a/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go index fc3010dd1..314e84501 100644 --- a/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go @@ -62,7 +62,7 @@ func (a *MockAccountManager) GetDeletePeerCalls() int { return a.deletePeerCalls } -func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { a.mu.Lock() defer a.mu.Unlock() if a.bufferUpdateCalls == nil { @@ -248,7 +248,7 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { return err } } - mockAM.BufferUpdateAccountPeers(ctx, accountID) + mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{}) return nil }). Times(1) diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index 7cb0f3908..c913efb92 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -154,9 +155,11 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs return err } - eventsToStore = append(eventsToStore, func() { - m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) - }) + if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") { + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + } return nil }) @@ -176,7 +179,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs } } - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index a7f692569..f2ecfd5f9 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -1,6 +1,7 @@ package accesslogs import ( + "maps" "net" "net/netip" "time" @@ -37,6 +38,7 @@ type AccessLogEntry struct { BytesUpload int64 `gorm:"index"` BytesDownload int64 `gorm:"index"` Protocol AccessLogProtocol `gorm:"index"` + Metadata map[string]string `gorm:"serializer:json"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -55,6 +57,7 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) { a.BytesUpload = serviceLog.GetBytesUpload() a.BytesDownload = serviceLog.GetBytesDownload() a.Protocol = AccessLogProtocol(serviceLog.GetProtocol()) + a.Metadata = maps.Clone(serviceLog.GetMetadata()) if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { if addr, err := netip.ParseAddr(sourceIP); err == nil { @@ -117,6 +120,11 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { protocol = &p } + var metadata *map[string]string + if len(a.Metadata) > 0 { + metadata = &a.Metadata + } + return &api.ProxyAccessLog{ Id: a.ID, ServiceId: a.ServiceID, @@ -136,5 +144,6 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { BytesUpload: a.BytesUpload, BytesDownload: a.BytesDownload, Protocol: protocol, + Metadata: metadata, } } diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go index e8d0ce763..59d7704eb 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go @@ -106,13 +106,23 @@ func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays in // StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) { - if retentionDays <= 0 { - log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative") + if retentionDays < 0 { + log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is negative") return } + if retentionDays == 0 { + retentionDays = 7 + log.WithContext(ctx).Debugf("no retention days specified for access log cleanup, defaulting to %d days", retentionDays) + } else { + log.WithContext(ctx).Debugf("access log retention period set to %d days", retentionDays) + } + if cleanupIntervalHours <= 0 { cleanupIntervalHours = 24 + log.WithContext(ctx).Debugf("no cleanup interval specified for access log cleanup, defaulting to %d hours", cleanupIntervalHours) + } else { + log.WithContext(ctx).Debugf("access log cleanup interval set to %d hours", cleanupIntervalHours) } cleanupCtx, cancel := context.WithCancel(ctx) diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go index 8fadef85f..11bf60829 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go @@ -121,7 +121,7 @@ func TestCleanupWithExactBoundary(t *testing.T) { } func TestStartPeriodicCleanup(t *testing.T) { - t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) { + t.Run("periodic cleanup disabled with negative retention", func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -135,7 +135,7 @@ func TestStartPeriodicCleanup(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - manager.StartPeriodicCleanup(ctx, 0, 1) + manager.StartPeriodicCleanup(ctx, -1, 1) time.Sleep(100 * time.Millisecond) diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index 859f1c5b2..f65e31a07 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -20,6 +20,9 @@ type Domain struct { // RequireSubdomain is populated at query time. When true, the domain // cannot be used bare and a subdomain label must be prepended. Not persisted. RequireSubdomain *bool `gorm:"-"` + // SupportsCrowdSec is populated at query time from proxy cluster capabilities. + // Not persisted. + SupportsCrowdSec *bool `gorm:"-"` } // EventMeta returns activity event metadata for a domain @@ -30,3 +33,8 @@ func (d *Domain) EventMeta() map[string]any { "validated": d.Validated, } } + +func (d *Domain) Copy() *Domain { + dCopy := *d + return &dCopy +} diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 640ab28a5..4493ef0ad 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -48,6 +48,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain { Validated: d.Validated, SupportsCustomPorts: d.SupportsCustomPorts, RequireSubdomain: d.RequireSubdomain, + SupportsCrowdsec: d.SupportsCrowdSec, } if d.TargetCluster != "" { resp.TargetCluster = &d.TargetCluster diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 901cdf0e3..2c4c1372e 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -31,19 +31,16 @@ type store interface { type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) -} - -type clusterCapabilities interface { - ClusterSupportsCustomPorts(clusterAddr string) *bool - ClusterRequireSubdomain(clusterAddr string) *bool + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool } type Manager struct { - store store - validator domain.Validator - proxyManager proxyManager - clusterCapabilities clusterCapabilities - permissionsManager permissions.Manager + store store + validator domain.Validator + proxyManager proxyManager + permissionsManager permissions.Manager accountManager account.Manager } @@ -57,11 +54,6 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio } } -// SetClusterCapabilities sets the cluster capabilities provider for domain queries. -func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) { - m.clusterCapabilities = caps -} - func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { @@ -97,10 +89,9 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d Type: domain.TypeFree, Validated: true, } - if m.clusterCapabilities != nil { - d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster) - d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster) - } + d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster) + d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster) + d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster) ret = append(ret, d) } @@ -114,8 +105,9 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d Type: domain.TypeCustom, Validated: d.Validated, } - if m.clusterCapabilities != nil && d.TargetCluster != "" { - cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster) + if d.TargetCluster != "" { + cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster) + cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster) } // Custom domains never require a subdomain by default since // the account owns them and should be able to use the bare domain. diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 9b0de53b4..aa7cd8630 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,11 +11,14 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error Disconnect(ctx context.Context, proxyID string) error Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusters(ctx context.Context) ([]Cluster, error) + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStale(ctx context.Context, inactivityDuration time.Duration) error } @@ -34,6 +37,4 @@ type Controller interface { RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error GetProxiesForCluster(clusterAddr string) []string - ClusterSupportsCustomPorts(clusterAddr string) *bool - ClusterRequireSubdomain(clusterAddr string) *bool } diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go index 05a0c9048..e5b3e9886 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/controller.go +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -72,17 +72,6 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster return nil } -// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports. -func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool { - return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr) -} - -// ClusterRequireSubdomain returns whether the cluster requires a subdomain label. -// Returns nil when no proxy has reported the capability (defaults to false). -func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool { - return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr) -} - // GetProxiesForCluster returns all proxy IDs registered for a specific cluster. func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { proxySet, ok := c.clusterProxies.Load(clusterAddr) diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index dac6d3ce3..d13334e83 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -16,6 +16,9 @@ type store interface { UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error } @@ -38,9 +41,14 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { }, nil } -// Connect registers a new proxy connection in the database -func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// Connect registers a new proxy connection in the database. +// capabilities may be nil for old proxies that do not report them. +func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { now := time.Now() + var caps proxy.Capabilities + if capabilities != nil { + caps = *capabilities + } p := &proxy.Proxy{ ID: proxyID, ClusterAddress: clusterAddress, @@ -48,6 +56,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress LastSeen: now, ConnectedAt: &now, Status: "connected", + Capabilities: caps, } if err := m.store.SaveProxy(ctx, p); err != nil { @@ -118,6 +127,24 @@ func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) return clusters, nil } +// ClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr) +} + +// ClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterRequireSubdomain(ctx, clusterAddr) +} + +// ClusterSupportsCrowdSec returns whether all active proxies in the cluster +// have CrowdSec configured (unanimous). Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr) +} + // CleanupStale removes proxies that haven't sent heartbeat in the specified duration func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index da3df12a2..282ca0ba5 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -50,18 +50,60 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration) } -// Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// ClusterSupportsCustomPorts mocks base method. +func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. +func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr) +} + +// ClusterRequireSubdomain mocks base method. +func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain. +func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr) +} + +// ClusterSupportsCrowdSec mocks base method. +func (m *MockManager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterSupportsCrowdSec", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCrowdSec indicates an expected call of ClusterSupportsCrowdSec. +func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr) +} + +// Connect mocks base method. +func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) ret0, _ := ret[0].(error) return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) } // Disconnect mocks base method. @@ -145,34 +187,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder { return m.recorder } -// ClusterSupportsCustomPorts mocks base method. -func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. -func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr) -} - -// ClusterRequireSubdomain mocks base method. -func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain. -func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr) -} - // GetOIDCValidationConfig mocks base method. func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 671eb109f..339c82446 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -2,6 +2,19 @@ package proxy import "time" +// Capabilities describes what a proxy can handle, as reported via gRPC. +// Nil fields mean the proxy never reported this capability. +type Capabilities struct { + // SupportsCustomPorts indicates whether this proxy can bind arbitrary + // ports for TCP/UDP services. TLS uses SNI routing and is not gated. + SupportsCustomPorts *bool + // RequireSubdomain indicates whether a subdomain label is required in + // front of the cluster domain. + RequireSubdomain *bool + // SupportsCrowdsec indicates whether this proxy has CrowdSec configured. + SupportsCrowdsec *bool +} + // Proxy represents a reverse proxy instance type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` @@ -11,6 +24,7 @@ type Proxy struct { ConnectedAt *time.Time DisconnectedAt *time.Time Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Capabilities Capabilities `gorm:"embedded"` CreatedAt time.Time UpdatedAt time.Time } diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index 8b652c7e1..fc91b8616 100644 --- a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -75,16 +75,19 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor require.NoError(t, err) mockCtrl := proxy.NewMockController(ctrl) - mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes() - mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes() mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes() + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes() + mockCaps.EXPECT().ClusterSupportsCrowdSec(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes() + accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, - GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { - return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, } @@ -93,6 +96,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor accountManager: accountMgr, permissionsManager: permissions.NewManager(testStore), proxyController: mockCtrl, + capabilities: mockCaps, clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, } mgr.exposeReaper = &exposeReaper{manager: mgr} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index ea4fa9d1e..0fb5f46ff 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -75,22 +76,30 @@ type ClusterDeriver interface { GetClusterDomains() []string } +// CapabilityProvider queries proxy cluster capabilities from the database. +type CapabilityProvider interface { + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool +} + type Manager struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager proxyController proxy.Controller + capabilities CapabilityProvider clusterDeriver ClusterDeriver exposeReaper *exposeReaper } // NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager { +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { mgr := &Manager{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, proxyController: proxyController, + capabilities: capabilities, clusterDeriver: clusterDeriver, } mgr.exposeReaper = &exposeReaper{manager: mgr} @@ -223,7 +232,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate}) return s, nil } @@ -237,7 +246,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri } service.ProxyCluster = proxyCluster - if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil { + if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil { return err } } @@ -268,11 +277,11 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri // validateSubdomainRequirement checks whether the domain can be used bare // (without a subdomain label) on the given cluster. If the cluster reports // require_subdomain=true and the domain equals the cluster domain, it rejects. -func (m *Manager) validateSubdomainRequirement(domain, cluster string) error { +func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error { if domain != cluster { return nil } - requireSub := m.proxyController.ClusterRequireSubdomain(cluster) + requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster) if requireSub != nil && *requireSub { return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain) } @@ -280,6 +289,8 @@ func (m *Manager) validateSubdomainRequirement(domain, cluster string) error { } func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { + customPorts := m.clusterCustomPorts(ctx, svc) + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if svc.Domain != "" { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { @@ -287,7 +298,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc * } } - if err := m.ensureL4Port(ctx, transaction, svc); err != nil { + if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil { return err } @@ -307,12 +318,23 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc * }) } -// ensureL4Port auto-assigns a listen port when needed and validates cluster support. -func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error { +// clusterCustomPorts queries whether the cluster supports custom ports. +// Must be called before entering a transaction: the underlying query uses +// the main DB handle, which deadlocks when called inside a transaction +// that already holds the connection. +func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service) *bool { + if !service.IsL4Protocol(svc.Mode) { + return nil + } + return m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster) +} + +// ensureL4Port auto-assigns a listen port when needed and validates cluster support. +// customPorts must be pre-computed via clusterCustomPorts before entering a transaction. +func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error { if !service.IsL4Protocol(svc.Mode) { return nil } - customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster) if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) { if svc.Source != service.SourceEphemeral { return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster) @@ -396,12 +418,14 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string // The count and exists queries use FOR UPDATE locking to serialize concurrent creates // for the same peer, preventing the per-peer limit from being bypassed. func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { + customPorts := m.clusterCustomPorts(ctx, svc) + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { return err } - if err := m.ensureL4Port(ctx, transaction, svc); err != nil { + if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil { return err } @@ -492,7 +516,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s } m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate}) return service, nil } @@ -504,28 +528,61 @@ type serviceUpdateInfo struct { } func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) { + effectiveCluster, err := m.resolveEffectiveCluster(ctx, accountID, service) + if err != nil { + return nil, err + } + + svcForCaps := *service + svcForCaps.ProxyCluster = effectiveCluster + customPorts := m.clusterCustomPorts(ctx, &svcForCaps) + var updateInfo serviceUpdateInfo - err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo) + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts) }) return &updateInfo, err } -func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error { +// resolveEffectiveCluster determines the cluster that will be used after the update. +// It reads the existing service without locking and derives the new cluster if the domain changed. +func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, svc *service.Service) (string, error) { + existing, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, svc.ID) + if err != nil { + return "", err + } + + if existing.Domain == svc.Domain { + return existing.ProxyCluster, nil + } + + if m.clusterDeriver != nil { + derived, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) + } else { + return derived, nil + } + } + + return existing.ProxyCluster, nil +} + +func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error { existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) if err != nil { return err } if existingService.Terminated { - return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated") - } + return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated") + } - if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { - return err - } + if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { + return err + } updateInfo.oldCluster = existingService.ProxyCluster updateInfo.domainChanged = existingService.Domain != service.Domain @@ -538,7 +595,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St service.ProxyCluster = existingService.ProxyCluster } - if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil { + if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil { return err } @@ -550,7 +607,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St m.preserveListenPort(service, existingService) updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled - if err := m.ensureL4Port(ctx, transaction, service); err != nil { + if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil { return err } if err := m.checkPortConflict(ctx, transaction, service); err != nil { @@ -763,7 +820,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -804,7 +861,7 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster) } - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -860,7 +917,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate}) return nil } @@ -1042,7 +1099,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s } m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate}) serviceURL := "https://" + svc.Domain if service.IsL4Protocol(svc.Mode) { @@ -1063,7 +1120,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr } groupIDs := make([]string, 0, len(groupNames)) for _, groupName := range groupNames { - g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID) + g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator) if err != nil { return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err) } @@ -1154,7 +1211,7 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -1205,7 +1262,7 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 18e1be26e..e9403849c 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,6 +19,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/mock_server" resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -29,6 +31,13 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + func TestInitializeServiceForCreate(t *testing.T) { ctx := context.Background() accountID := "test-account" @@ -422,10 +431,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer { t.Helper() - tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) return srv } @@ -440,7 +447,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { storedActivity = activityID.(activity.Activity) }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). @@ -542,7 +549,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { storedActivity = activityID.(activity.Activity) }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). @@ -586,7 +593,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) { storedMeta = meta }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). @@ -697,16 +704,14 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, - GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { - return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, } - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) @@ -1128,10 +1133,8 @@ func TestDeleteService_DeletesTargets(t *testing.T) { mockPerms := permissions.NewMockManager(ctrl) mockAcct := account.NewMockManager(ctrl) - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) @@ -1170,7 +1173,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) { mockAcct.EXPECT(). StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any()) mockAcct.EXPECT(). - UpdateAccountPeers(ctx, accountID) + UpdateAccountPeers(ctx, accountID, gomock.Any()) err = mgr.DeleteService(ctx, accountID, userID, service.ID) require.NoError(t, err) @@ -1324,11 +1327,11 @@ func TestValidateSubdomainRequirement(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctrl := gomock.NewController(t) - mockCtrl := proxy.NewMockController(ctrl) - mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes() + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes() - mgr := &Manager{proxyController: mockCtrl} - err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster) + mgr := &Manager{capabilities: mockCaps} + err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster) if tc.wantErr { require.Error(t, err) assert.Contains(t, err.Error(), "requires a subdomain label") diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index d956013ea..769e037bc 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -113,6 +113,7 @@ type AccessRestrictions struct { BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"` AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"` BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"` + CrowdSecMode string `json:"crowdsec_mode,omitempty" gorm:"serializer:json"` } // Copy returns a deep copy of the AccessRestrictions. @@ -122,6 +123,7 @@ func (r AccessRestrictions) Copy() AccessRestrictions { BlockedCIDRs: slices.Clone(r.BlockedCIDRs), AllowedCountries: slices.Clone(r.AllowedCountries), BlockedCountries: slices.Clone(r.BlockedCountries), + CrowdSecMode: r.CrowdSecMode, } } @@ -555,7 +557,11 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro } if req.AccessRestrictions != nil { - s.Restrictions = restrictionsFromAPI(req.AccessRestrictions) + restrictions, err := restrictionsFromAPI(req.AccessRestrictions) + if err != nil { + return err + } + s.Restrictions = restrictions } return nil @@ -631,9 +637,9 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig { return auth } -func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions { +func restrictionsFromAPI(r *api.AccessRestrictions) (AccessRestrictions, error) { if r == nil { - return AccessRestrictions{} + return AccessRestrictions{}, nil } var res AccessRestrictions if r.AllowedCidrs != nil { @@ -648,11 +654,19 @@ func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions { if r.BlockedCountries != nil { res.BlockedCountries = *r.BlockedCountries } - return res + if r.CrowdsecMode != nil { + if !r.CrowdsecMode.Valid() { + return AccessRestrictions{}, fmt.Errorf("invalid crowdsec_mode %q", *r.CrowdsecMode) + } + res.CrowdSecMode = string(*r.CrowdsecMode) + } + return res, nil } func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions { - if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 { + if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && + len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 && + r.CrowdSecMode == "" { return nil } res := &api.AccessRestrictions{} @@ -668,11 +682,17 @@ func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions { if len(r.BlockedCountries) > 0 { res.BlockedCountries = &r.BlockedCountries } + if r.CrowdSecMode != "" { + mode := api.AccessRestrictionsCrowdsecMode(r.CrowdSecMode) + res.CrowdsecMode = &mode + } return res } func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions { - if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 { + if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && + len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 && + r.CrowdSecMode == "" { return nil } return &proto.AccessRestrictions{ @@ -680,6 +700,7 @@ func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions { BlockedCidrs: r.BlockedCIDRs, AllowedCountries: r.AllowedCountries, BlockedCountries: r.BlockedCountries, + CrowdsecMode: r.CrowdSecMode, } } @@ -787,6 +808,11 @@ func (s *Service) validateHTTPTargets() error { } func (s *Service) validateL4Target(target *Target) error { + // L4 services have a single target; per-target disable is meaningless + // (use the service-level Enabled flag instead). Force it on so that + // buildPathMappings always includes the target in the proto. + target.Enabled = true + if target.Port == 0 { return errors.New("target port is required for L4 services") } @@ -983,7 +1009,20 @@ const ( // validateAccessRestrictions validates and normalizes access restriction // entries. Country codes are uppercased in place. +func validateCrowdSecMode(mode string) error { + switch mode { + case "", "off", "enforce", "observe": + return nil + default: + return fmt.Errorf("crowdsec_mode %q is invalid", mode) + } +} + func validateAccessRestrictions(r *AccessRestrictions) error { + if err := validateCrowdSecMode(r.CrowdSecMode); err != nil { + return err + } + if len(r.AllowedCIDRs) > maxCIDREntries { return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries) } @@ -997,35 +1036,37 @@ func validateAccessRestrictions(r *AccessRestrictions) error { return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries) } - for i, raw := range r.AllowedCIDRs { + if err := validateCIDRList("allowed_cidrs", r.AllowedCIDRs); err != nil { + return err + } + if err := validateCIDRList("blocked_cidrs", r.BlockedCIDRs); err != nil { + return err + } + if err := normalizeCountryList("allowed_countries", r.AllowedCountries); err != nil { + return err + } + return normalizeCountryList("blocked_countries", r.BlockedCountries) +} + +func validateCIDRList(field string, cidrs []string) error { + for i, raw := range cidrs { prefix, err := netip.ParsePrefix(raw) if err != nil { - return fmt.Errorf("allowed_cidrs[%d]: %w", i, err) + return fmt.Errorf("%s[%d]: %w", field, i, err) } if prefix != prefix.Masked() { - return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked()) + return fmt.Errorf("%s[%d]: %q has host bits set, use %s instead", field, i, raw, prefix.Masked()) } } - for i, raw := range r.BlockedCIDRs { - prefix, err := netip.ParsePrefix(raw) - if err != nil { - return fmt.Errorf("blocked_cidrs[%d]: %w", i, err) - } - if prefix != prefix.Masked() { - return fmt.Errorf("blocked_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked()) - } - } - for i, code := range r.AllowedCountries { + return nil +} + +func normalizeCountryList(field string, codes []string) error { + for i, code := range codes { if len(code) != 2 { - return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code) + return fmt.Errorf("%s[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", field, i, code) } - r.AllowedCountries[i] = strings.ToUpper(code) - } - for i, code := range r.BlockedCountries { - if len(code) != 2 { - return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code) - } - r.BlockedCountries[i] = strings.ToUpper(code) + codes[i] = strings.ToUpper(code) } return nil } diff --git a/management/internals/modules/zones/manager/manager.go b/management/internals/modules/zones/manager/manager.go index 8548dd48c..439671e65 100644 --- a/management/internals/modules/zones/manager/manager.go +++ b/management/internals/modules/zones/manager/manager.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -144,7 +145,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta()) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationUpdate}) return zone, nil } @@ -206,7 +207,7 @@ func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/zones/records/manager/manager.go b/management/internals/modules/zones/records/manager/manager.go index 5374a2ef2..7458b41db 100644 --- a/management/internals/modules/zones/records/manager/manager.go +++ b/management/internals/modules/zones/records/manager/manager.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -95,7 +96,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationCreate}) return record, nil } @@ -154,7 +155,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationUpdate}) return record, nil } @@ -201,7 +202,7 @@ func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 88d37ca80..f2ab0a2c4 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -26,8 +27,10 @@ import ( accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" @@ -58,6 +61,18 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { }) } +// CacheStore returns a shared cache store backed by Redis or in-memory depending on the environment. +// All consumers should reuse this store to avoid creating multiple Redis connections. +func (s *BaseServer) CacheStore() cachestore.StoreInterface { + return Create(s, func() cachestore.StoreInterface { + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultStoreMaxTimeout, nbcache.DefaultStoreCleanupInterval, nbcache.DefaultStoreMaxConn) + if err != nil { + log.Fatalf("failed to create shared cache store: %v", err) + } + return cs + }) +} + func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) @@ -95,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -103,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler { }) } +func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter { + return Create(s, func() *middleware.APIRateLimiter { + cfg, enabled := middleware.RateLimiterConfigFromEnv() + limiter := middleware.NewAPIRateLimiter(cfg) + limiter.SetEnabled(enabled) + return limiter + }) +} + func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { trustedPeers := s.Config.ReverseProxy.TrustedPeers @@ -149,7 +173,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider()) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore()) if err != nil { log.Fatalf("failed to create management server: %v", err) } @@ -195,10 +219,7 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore { - tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100) - if err != nil { - log.Fatalf("failed to create proxy token store: %v", err) - } + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), s.CacheStore()) log.Info("One-time token store initialized for proxy authentication") return tokenStore }) @@ -206,11 +227,7 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore { return Create(s, func() *nbgrpc.PKCEVerifierStore { - pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) - if err != nil { - log.Fatalf("failed to create PKCE verifier store: %v", err) - } - return pkceStore + return nbgrpc.NewPKCEVerifierStore(context.Background(), s.CacheStore()) }) } diff --git a/management/internals/server/config/config.go b/management/internals/server/config/config.go index 0ba393263..fb9c842b7 100644 --- a/management/internals/server/config/config.go +++ b/management/internals/server/config/config.go @@ -203,7 +203,7 @@ type ReverseProxy struct { // AccessLogRetentionDays specifies the number of days to retain access logs. // Logs older than this duration will be automatically deleted during cleanup. - // A value of 0 or negative means logs are kept indefinitely (no cleanup). + // A value of 0 will default to 7 days. Negative means logs are kept indefinitely (no cleanup). AccessLogRetentionDays int // AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine. diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 62ed659c0..89bdf0abe 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -20,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { @@ -40,7 +42,8 @@ func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValida context.Background(), s.PeersManager(), s.SettingsManager(), - s.EventStore()) + s.EventStore(), + s.CacheStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } @@ -64,6 +67,12 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager { }) } +func (s *BaseServer) SessionStore() *auth.SessionStore { + return Create(s, func() *auth.SessionStore { + return auth.NewSessionStore(s.CacheStore()) + }) +} + func (s *BaseServer) AuthManager() auth.Manager { audiences := s.Config.GetAuthAudiences() audience := s.Config.HttpConfig.AuthAudience @@ -71,6 +80,7 @@ func (s *BaseServer) AuthManager() auth.Manager { signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled issuer := s.Config.HttpConfig.AuthIssuer userIDClaim := s.Config.HttpConfig.AuthUserIDClaim + var keyFetcher nbjwt.KeyFetcher // Use embedded IdP configuration if available if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil { @@ -78,8 +88,11 @@ func (s *BaseServer) AuthManager() auth.Manager { if len(audiences) > 0 { audience = audiences[0] // Use the first client ID as the primary audience } - // Use localhost keys location for internal validation (management has embedded Dex) - keysLocation = oauthProvider.GetLocalKeysLocation() + keyFetcher = oauthProvider.GetKeyFetcher() + // Fall back to default keys location if direct key fetching is not available + if keyFetcher == nil { + keysLocation = oauthProvider.GetLocalKeysLocation() + } signingKeyRefreshEnabled = true issuer = oauthProvider.GetIssuer() userIDClaim = oauthProvider.GetUserIDClaim() @@ -92,7 +105,8 @@ func (s *BaseServer) AuthManager() auth.Manager { keysLocation, userIDClaim, audiences, - signingKeyRefreshEnabled) + signingKeyRefreshEnabled, + keyFetcher) }) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index a32cf6046..9b2ec2989 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -100,7 +100,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy, s.CacheStore()) if err != nil { log.Fatalf("failed to create account service: %v", err) } @@ -117,9 +117,11 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error + // Use embedded IdP service if embedded Dex is configured and enabled. // Legacy IdpManager won't be used anymore even if configured. - if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { + embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled + if embeddedEnabled { idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) if err != nil { log.Fatalf("failed to create embedded IDP service: %v", err) @@ -195,7 +197,7 @@ func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) ServiceManager() service.Manager { return Create(s, func() service.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager()) + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) }) } @@ -212,9 +214,6 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) - s.AfterInit(func(s *BaseServer) { - m.SetClusterCapabilities(s.ServiceProxyController()) - }) return &m }) } diff --git a/management/internals/shared/grpc/onetime_token.go b/management/internals/shared/grpc/onetime_token.go index 7999407db..acfd6eafb 100644 --- a/management/internals/shared/grpc/onetime_token.go +++ b/management/internals/shared/grpc/onetime_token.go @@ -14,8 +14,6 @@ import ( "github.com/eko/gocache/lib/v4/cache" "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" - - nbcache "github.com/netbirdio/netbird/management/server/cache" ) type tokenMetadata struct { @@ -32,17 +30,12 @@ type OneTimeTokenStore struct { ctx context.Context } -// NewOneTimeTokenStore creates a token store with automatic backend selection -func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) { - cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) - if err != nil { - return nil, fmt.Errorf("failed to create cache store: %w", err) - } - +// NewOneTimeTokenStore creates a token store using the provided shared cache store. +func NewOneTimeTokenStore(ctx context.Context, cacheStore store.StoreInterface) *OneTimeTokenStore { return &OneTimeTokenStore{ cache: cache.New[string](cacheStore), ctx: ctx, - }, nil + } } // GenerateToken creates a new cryptographically secure one-time token diff --git a/management/internals/shared/grpc/pkce_verifier.go b/management/internals/shared/grpc/pkce_verifier.go index 441e8b051..a1325256c 100644 --- a/management/internals/shared/grpc/pkce_verifier.go +++ b/management/internals/shared/grpc/pkce_verifier.go @@ -8,8 +8,6 @@ import ( "github.com/eko/gocache/lib/v4/cache" "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" - - nbcache "github.com/netbirdio/netbird/management/server/cache" ) // PKCEVerifierStore manages PKCE verifiers for OAuth flows. @@ -19,17 +17,12 @@ type PKCEVerifierStore struct { ctx context.Context } -// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection -func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) { - cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) - if err != nil { - return nil, fmt.Errorf("failed to create cache store: %w", err) - } - +// NewPKCEVerifierStore creates a PKCE verifier store using the provided shared cache store. +func NewPKCEVerifierStore(ctx context.Context, cacheStore store.StoreInterface) *PKCEVerifierStore { return &PKCEVerifierStore{ cache: cache.New[string](cacheStore), ctx: ctx, - }, nil + } } // Store saves a PKCE verifier associated with an OAuth state parameter. diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 5fa382af0..a5e352e75 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -182,9 +182,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) } - // Register proxy in database - if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err) + // Register proxy in database with capabilities + var caps *proxy.Capabilities + if c := req.GetCapabilities(); c != nil { + caps = &proxy.Capabilities{ + SupportsCustomPorts: c.SupportsCustomPorts, + RequireSubdomain: c.RequireSubdomain, + SupportsCrowdsec: c.SupportsCrowdsec, + } + } + if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) + s.connectedProxies.Delete(proxyID) + if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) + } + return status.Errorf(codes.Internal, "register proxy in database: %v", err) } log.WithFields(log.Fields{ @@ -297,6 +310,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * } m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + if !proxyAcceptsMapping(conn, m) { + continue + } mappings = append(mappings, m) } return mappings, nil @@ -445,22 +461,46 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd log.Debugf("Sending service update to cluster %s", clusterAddr) for _, proxyID := range proxyIDs { - if connVal, ok := s.connectedProxies.Load(proxyID); ok { - conn := connVal.(*proxyConnection) - msg := s.perProxyMessage(updateResponse, proxyID) - if msg == nil { - continue - } - select { - case conn.sendChan <- msg: - log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) - default: - log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) - } + connVal, ok := s.connectedProxies.Load(proxyID) + if !ok { + continue + } + conn := connVal.(*proxyConnection) + if !proxyAcceptsMapping(conn, update) { + log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) + continue + } + msg := s.perProxyMessage(updateResponse, proxyID) + if msg == nil { + continue + } + select { + case conn.sendChan <- msg: + log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + default: + log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) } } } +// proxyAcceptsMapping returns whether the proxy should receive this mapping. +// Old proxies that never reported capabilities are skipped for non-TLS L4 +// mappings with a custom listen port, since they don't understand the +// protocol. Proxies that report capabilities (even SupportsCustomPorts=false) +// are new enough to handle the mapping. TLS uses SNI routing and works on +// any proxy. Delete operations are always sent so proxies can clean up. +func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool { + if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { + return true + } + if mapping.ListenPort == 0 || mapping.Mode == "tls" { + return true + } + // Old proxies that never reported capabilities don't understand + // custom port mappings. + return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil +} + // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. @@ -508,64 +548,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } } -// ClusterSupportsCustomPorts returns whether any connected proxy in the given -// cluster reports custom port support. Returns nil if no proxy has reported -// capabilities (old proxies that predate the field). -func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool { - if s.proxyController == nil { - return nil - } - - var hasCapabilities bool - for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { - connVal, ok := s.connectedProxies.Load(pid) - if !ok { - continue - } - conn := connVal.(*proxyConnection) - if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil { - continue - } - if *conn.capabilities.SupportsCustomPorts { - return ptr(true) - } - hasCapabilities = true - } - if hasCapabilities { - return ptr(false) - } - return nil -} - -// ClusterRequireSubdomain returns whether any connected proxy in the given -// cluster reports that a subdomain is required. Returns nil if no proxy has -// reported the capability (defaults to not required). -func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool { - if s.proxyController == nil { - return nil - } - - var hasCapabilities bool - for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { - connVal, ok := s.connectedProxies.Load(pid) - if !ok { - continue - } - conn := connVal.(*proxyConnection) - if conn.capabilities == nil || conn.capabilities.RequireSubdomain == nil { - continue - } - if *conn.capabilities.RequireSubdomain { - return ptr(true) - } - hasCapabilities = true - } - if hasCapabilities { - return ptr(false) - } - return nil -} - func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 83c99020d..de4e96d93 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -9,13 +9,22 @@ import ( "testing" "time" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/shared/management/proto" ) +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + type testProxyController struct { mu sync.Mutex clusterProxies map[string]map[string]struct{} @@ -53,14 +62,6 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus return nil } -func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { - return ptr(true) -} - -func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool { - return nil -} - func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { c.mu.Lock() defer c.mu.Unlock() @@ -122,11 +123,8 @@ func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool { func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -182,11 +180,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -219,11 +214,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -275,8 +267,7 @@ func generateState(s *ProxyServiceServer, redirectURL string) string { func TestOAuthState_NeverTheSame(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -304,8 +295,7 @@ func TestOAuthState_NeverTheSame(t *testing.T) { func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -315,7 +305,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { } // Old format had only 2 parts: base64(url)|hmac - err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) + err := s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) require.NoError(t, err) _, _, err = s.ValidateState("base64url|hmac") @@ -325,8 +315,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { func TestValidateState_RejectsInvalidHMAC(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -336,7 +325,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { } // Store with tampered HMAC - err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) + err := s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) require.NoError(t, err) _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac") @@ -345,8 +334,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { } func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -355,14 +343,14 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { const cluster = "proxy.example.com" - // Proxy A supports custom ports. - chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) - // Proxy B does NOT support custom ports (shared cloud proxy). - chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + // Modern proxy reports capabilities. + chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + // Legacy proxy never reported capabilities (nil). + chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) ctx := context.Background() - // TLS passthrough works on all proxies regardless of custom port support. + // TLS passthrough with custom port: all proxies receive it (SNI routing). tlsMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-tls", @@ -375,12 +363,26 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster) - msgA := drainMapping(chA) - msgB := drainMapping(chB) - assert.NotNil(t, msgA, "proxy-a should receive TLS mapping") - assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)") + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)") - // Send an HTTP mapping: both should receive it. + // TCP mapping with custom port: only modern proxy receives it. + tcpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tcp", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tcp", + ListenPort: 5432, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping") + assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping") + + // HTTP mapping (no listen port): both receive it. httpMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-http", @@ -391,15 +393,20 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { s.SendServiceUpdateToCluster(ctx, httpMapping, cluster) - msgA = drainMapping(chA) - msgB = drainMapping(chB) - assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping") - assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping") + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping") + + // Proxy that reports SupportsCustomPorts=false still receives custom-port + // mappings because it understands the protocol (it's new enough). + chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping") } func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -408,7 +415,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { const cluster = "proxy.example.com" - chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + // Legacy proxy (no capabilities) still receives TLS since it uses SNI. + chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) tlsMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, @@ -421,16 +429,15 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster) - msg := drainMapping(chShared) - assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support") + msg := drainMapping(chLegacy) + assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)") } // TestServiceModifyNotifications exercises every possible modification // scenario for an existing service, verifying the correct update types // reach the correct clusters. func TestServiceModifyNotifications(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) { s := &ProxyServiceServer{ @@ -589,7 +596,7 @@ func TestServiceModifyNotifications(t *testing.T) { s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) - chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + chLegacy := registerFakeProxy(s, "legacy", cluster) // TLS passthrough works on all proxies regardless of custom port support s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster) @@ -608,7 +615,7 @@ func TestServiceModifyNotifications(t *testing.T) { } s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" - chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + chLegacy := registerFakeProxy(s, "legacy", cluster) mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) mapping.ListenPort = 0 // default port diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6e8358f02..0c1611e7f 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + jwtv5 "github.com/golang-jwt/jwt/v5" pb "github.com/golang/protobuf/proto" // nolint "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" @@ -67,6 +68,7 @@ type Server struct { appMetrics telemetry.AppMetrics peerLocks sync.Map authManager auth.Manager + sessionStore *auth.SessionStore logBlockedPeers bool blockPeersWithSameConfig bool @@ -98,6 +100,7 @@ func NewServer( integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, oAuthConfigProvider idp.OAuthConfigProvider, + sessionStore *auth.SessionStore, ) (*Server, error) { if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams @@ -140,6 +143,7 @@ func NewServer( integratedPeerValidator: integratedPeerValidator, networkMapController: networkMapController, oAuthConfigProvider: oAuthConfigProvider, + sessionStore: sessionStore, loginFilter: newLoginFilter(), @@ -535,7 +539,7 @@ func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID st log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key) } -func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { +func (s *Server) validateToken(ctx context.Context, peerKey, jwtToken string) (string, error) { if s.authManager == nil { return "", status.Errorf(codes.Internal, "missing auth manager") } @@ -545,6 +549,10 @@ func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, er return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } + if err := s.claimLoginToken(ctx, peerKey, jwtToken, token); err != nil { + return "", err + } + // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth) if err != nil { @@ -828,6 +836,31 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne return loginResp, nil } +func (s *Server) claimLoginToken(ctx context.Context, peerKey, jwtToken string, token *jwtv5.Token) error { + if s.sessionStore == nil || token == nil { + return nil + } + + exp, err := token.Claims.GetExpirationTime() + if err != nil || exp == nil { + log.WithContext(ctx).Warnf("JWT has no usable exp claim for peer %s", peerKey) + return status.Error(codes.Unauthenticated, "jwt token has no expiration") + } + + err = s.sessionStore.RegisterToken(ctx, jwtToken, exp.Time) + if err == nil { + return nil + } + + if errors.Is(err, auth.ErrTokenAlreadyUsed) || errors.Is(err, auth.ErrTokenExpired) { + log.WithContext(ctx).Warnf("%v for peer %s", err, peerKey) + return status.Error(codes.Unauthenticated, err.Error()) + } + + log.WithContext(ctx).Warnf("failed to claim JWT for peer %s: %v", peerKey, err) + return status.Error(codes.Unavailable, "failed to claim jwt token") +} + // processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if // the token is valid. // @@ -838,7 +871,7 @@ func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginReque if loginReq.GetJwtToken() != "" { var err error for i := 0; i < 3; i++ { - userID, err = s.validateToken(ctx, loginReq.GetJwtToken()) + userID, err = s.validateToken(ctx, peerKey.String(), loginReq.GetJwtToken()) if err == nil { break } diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 2f77de86e..d1d7fc8b7 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -39,11 +39,8 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { usersManager := &testValidateSessionUsersManager{store: testStore} proxyManager := &testValidateSessionProxyManager{} - tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) proxyService.SetServiceManager(serviceManager) @@ -327,7 +324,7 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { return nil } @@ -335,7 +332,7 @@ func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string return nil } -func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error { +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { return nil } @@ -351,6 +348,18 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time return nil } +func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + type testValidateSessionUsersManager struct { store store.Store } diff --git a/management/server/account.go b/management/server/account.go index 75db36a5f..4b71ab486 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -181,7 +181,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] return modified, newUserAutoGroups, newGroupsToCreate, nil } -// BuildManager creates a new DefaultAccountManager with a provided Store +// BuildManager creates a new DefaultAccountManager with all dependencies. func BuildManager( ctx context.Context, config *nbconfig.Config, @@ -199,6 +199,7 @@ func BuildManager( settingsManager settings.Manager, permissionsManager permissions.Manager, disableDefaultPolicy bool, + sharedCacheStore cacheStore.StoreInterface, ) (*DefaultAccountManager, error) { start := time.Now() defer func() { @@ -247,16 +248,12 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) - if err != nil { - return nil, fmt.Errorf("getting cache store: %s", err) - } - am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) - am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) + am.externalCacheManager = nbcache.NewUserDataCache(sharedCacheStore) + am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, sharedCacheStore) if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { go func() { - err := am.warmupIDPCache(ctx, cacheStore) + err := am.warmupIDPCache(ctx, sharedCacheStore) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? @@ -403,7 +400,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { - go am.UpdateAccountPeers(ctx, accountID) + go am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceAccountSettings, Operation: types.UpdateOperationUpdate}) } return newSettings, nil @@ -742,11 +739,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) } - err = am.serviceManager.DeleteAllServices(ctx, accountID, userID) - if err != nil { - return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err) - } - for _, otherUser := range account.Users { if otherUser.Id == userID { continue @@ -1589,7 +1581,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) - am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) + am.BufferUpdateAccountPeers(ctx, userAuth.AccountId, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 45af63ae8..626ed222d 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -75,7 +75,7 @@ type Manager interface { GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error @@ -124,8 +124,8 @@ type Manager interface { GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error - UpdateAccountPeers(ctx context.Context, accountID string) - BufferUpdateAccountPeers(ctx context.Context, accountID string) + UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) + BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error GetStore() store.Store diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 90700c795..8f3b22ecc 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -111,15 +111,15 @@ func (mr *MockManagerMockRecorder) ApproveUser(ctx, accountID, initiatorUserID, } // BufferUpdateAccountPeers mocks base method. -func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason) } // BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. -func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason) } // BuildUserInfosForAccount mocks base method. @@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte } // GetGroupByName mocks base method. -func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { +func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID) + ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID) ret0, _ := ret[0].(*types.Group) ret1, _ := ret[1].(error) return ret0, ret1 } // GetGroupByName indicates an expected call of GetGroupByName. -func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID) } // GetIdentityProvider mocks base method. @@ -1597,15 +1597,15 @@ func (mr *MockManagerMockRecorder) UpdateAccountOnboarding(ctx, accountID, userI } // UpdateAccountPeers mocks base method. -func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string) { +func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason) } // UpdateAccountPeers indicates an expected call of UpdateAccountPeers. -func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason) } // UpdateAccountSettings mocks base method. diff --git a/management/server/account/pat.go b/management/server/account/pat.go new file mode 100644 index 000000000..8e5e3e3f9 --- /dev/null +++ b/management/server/account/pat.go @@ -0,0 +1,8 @@ +package account + +const ( + // PATMinExpireDays is the minimum allowed Personal Access Token expiration in days. + PATMinExpireDays = 1 + // PATMaxExpireDays is the maximum allowed Personal Access Token expiration in days. + PATMaxExpireDays = 365 +) diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index ac53a9fa8..e1672c2d0 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -63,20 +63,11 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID) startTime := time.Now() + ac.getAccountRequestCh <- req - select { - case <-ctx.Done(): - return nil, ctx.Err() - case ac.getAccountRequestCh <- req: - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case result := <-req.ResultChan: - log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime)) - return result.Account, result.Err - } + result := <-req.ResultChan + log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime)) + return result.Account, result.Err } func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) { diff --git a/management/server/account_test.go b/management/server/account_test.go index fdec43617..e259856e3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -15,7 +15,6 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/netbirdio/netbird/shared/management/status" "github.com/prometheus/client_golang/prometheus/push" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -23,6 +22,9 @@ import ( "go.opentelemetry.io/otel/metric/noop" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/shared/management/status" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -406,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) + networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -1169,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_SaveGroup(t) -} - func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { testAccountManager_NetworkUpdates_SaveGroup(t) } @@ -1229,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeletePolicy(t) -} - func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { testAccountManager_NetworkUpdates_DeletePolicy(t) } @@ -1272,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_SavePolicy(t) -} - func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { testAccountManager_NetworkUpdates_SavePolicy(t) } @@ -1330,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeletePeer(t) -} - func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { testAccountManager_NetworkUpdates_DeletePeer(t) } @@ -1395,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeleteGroup(t) -} - func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { testAccountManager_NetworkUpdates_DeleteGroup(t) } @@ -1631,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) } -func TestAccount_GetRoutesToSync(t *testing.T) { - _, prefix, err := route.ParseNetwork("192.168.64.0/24") - if err != nil { - t.Fatal(err) - } - _, prefix2, err := route.ParseNetwork("192.168.0.0/24") - if err != nil { - t.Fatal(err) - } - account := &types.Account{ - Peers: map[string]*nbpeer.Peer{ - "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, - }, - Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, - Routes: map[route.ID]*route.Route{ - "route-1": { - ID: "route-1", - Network: prefix, - NetID: "network-1", - Description: "network-1", - Peer: "peer-1", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - "route-2": { - ID: "route-2", - Network: prefix2, - NetID: "network-2", - Description: "network-2", - Peer: "peer-2", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - "route-3": { - ID: "route-3", - Network: prefix, - NetID: "network-1", - Description: "network-1", - Peer: "peer-2", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - }, - } - - routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2")) - - assert.Len(t, routes, 2) - routeIDs := make(map[route.ID]struct{}, 2) - for _, r := range routes { - routeIDs[r.ID] = struct{}{} - } - assert.Contains(t, routeIDs, route.ID("route-2")) - assert.Contains(t, routeIDs, route.ID("route-3")) - - emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3")) - - assert.Len(t, emptyRoutes, 0) -} - func TestAccount_Copy(t *testing.T) { account := &types.Account{ Id: "account1", @@ -1815,9 +1723,14 @@ func TestAccount_Copy(t *testing.T) { Targets: []*service.Target{}, }, }, - NetworkMapCache: &types.NetworkMapBuilder{}, + Domains: []*domain.Domain{ + { + ID: "domain1", + Domain: "test.com", + AccountID: "account1", + }, + }, } - account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) @@ -1848,7 +1761,7 @@ func hasNilField(x interface{}) error { if f := rv.Field(i); f.IsValid() { k := f.Kind() switch k { - case reflect.Ptr: + case reflect.Pointer: if f.IsNil() { return fmt.Errorf("field %s is nil", f.String()) } @@ -2302,6 +2215,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) { + ctx := context.Background() + + testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanUp) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + // Verify the already-expired peer is excluded at the store level + peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query") + assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired") + } + + // Only the non-expired peer with expiration enabled should be returned + require.Len(t, peers, 1) + assert.Equal(t, "notexpired01", peers[0].ID) +} + func TestAccount_GetInactivePeers(t *testing.T) { type test struct { name string @@ -3125,10 +3061,15 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU ctx := context.Background() + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } @@ -3138,7 +3079,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU if err != nil { return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil)) + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil)) return manager, updateManager, nil } @@ -3216,6 +3157,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel. return manager, updateManager, account, peer1, peer2, peer3 } +// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer +// wrappers wait for an expected update message. Sized for slow CI runners +// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take +// seconds. Only runs down on failure; passing tests return immediately +// when the channel delivers. +const peerUpdateTimeout = 5 * time.Second + func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { @@ -3234,7 +3182,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd if msg == nil { t.Errorf("Received nil update message, expected valid message") } - case <-time.After(500 * time.Millisecond): + case <-time.After(peerUpdateTimeout): t.Error("Timed out waiting for update message") } } diff --git a/management/server/activity/store/sql_store_idp_migration.go b/management/server/activity/store/sql_store_idp_migration.go new file mode 100644 index 000000000..1b3a9ecd9 --- /dev/null +++ b/management/server/activity/store/sql_store_idp_migration.go @@ -0,0 +1,61 @@ +package store + +// This file contains migration-only methods on Store. +// They satisfy the migration.MigrationEventStore interface via duck typing. +// Delete this file when migration tooling is no longer needed. + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/idp/migration" +) + +// CheckSchema verifies that all tables and columns required by the migration exist in the event database. +func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError { + migrator := store.db.Migrator() + var errs []migration.SchemaError + + for _, check := range checks { + if !migrator.HasTable(check.Table) { + errs = append(errs, migration.SchemaError{Table: check.Table}) + continue + } + for _, col := range check.Columns { + if !migrator.HasColumn(check.Table, col) { + errs = append(errs, migration.SchemaError{Table: check.Table, Column: col}) + } + } + } + + return errs +} + +// UpdateUserID updates all references to oldUserID in events and deleted_users tables. +func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error { + return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&activity.Event{}). + Where("initiator_id = ?", oldUserID). + Update("initiator_id", newUserID).Error; err != nil { + return fmt.Errorf("update events.initiator_id: %w", err) + } + + if err := tx.Model(&activity.Event{}). + Where("target_id = ?", oldUserID). + Update("target_id", newUserID).Error; err != nil { + return fmt.Errorf("update events.target_id: %w", err) + } + + // Raw exec: GORM can't update a PK via Model().Update() + if err := tx.Exec( + "UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID, + ).Error; err != nil { + return fmt.Errorf("update deleted_users.id: %w", err) + } + + return nil + }) +} diff --git a/management/server/activity/store/sql_store_idp_migration_test.go b/management/server/activity/store/sql_store_idp_migration_test.go new file mode 100644 index 000000000..98b6e1327 --- /dev/null +++ b/management/server/activity/store/sql_store_idp_migration_test.go @@ -0,0 +1,161 @@ +package store + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/util/crypt" +) + +func TestUpdateUserID(t *testing.T) { + ctx := context.Background() + + newStore := func(t *testing.T) *Store { + t.Helper() + key, _ := crypt.GenerateKey() + s, err := NewSqlStore(ctx, t.TempDir(), key) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { s.Close(ctx) }) //nolint + return s + } + + t.Run("updates initiator_id in events", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "old-user", + TargetID: "some-peer", + AccountID: accountID, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "old-user", "new-user") + assert.NoError(t, err) + + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "new-user", result[0].InitiatorID) + }) + + t.Run("updates target_id in events", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "some-admin", + TargetID: "old-user", + AccountID: accountID, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "old-user", "new-user") + assert.NoError(t, err) + + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "new-user", result[0].TargetID) + }) + + t.Run("updates deleted_users id", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + // Save an event with email/name meta to create a deleted_users row for "old-user" + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "admin", + TargetID: "old-user", + AccountID: accountID, + Meta: map[string]any{ + "email": "user@example.com", + "name": "Test User", + }, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "old-user", "new-user") + assert.NoError(t, err) + + // Save another event referencing new-user with email/name meta. + // This should upsert (not conflict) because the PK was already migrated. + _, err = store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "admin", + TargetID: "new-user", + AccountID: accountID, + Meta: map[string]any{ + "email": "user@example.com", + "name": "Test User", + }, + }) + assert.NoError(t, err) + + // The deleted user info should be retrievable via Get (joined on target_id) + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 2) + for _, ev := range result { + assert.Equal(t, "new-user", ev.TargetID) + } + }) + + t.Run("no-op when old user ID does not exist", func(t *testing.T) { + store := newStore(t) + + err := store.UpdateUserID(ctx, "nonexistent-user", "new-user") + assert.NoError(t, err) + }) + + t.Run("only updates matching user leaves others unchanged", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "user-a", + TargetID: "peer-1", + AccountID: accountID, + }) + assert.NoError(t, err) + + _, err = store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "user-b", + TargetID: "peer-2", + AccountID: accountID, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "user-a", "user-a-new") + assert.NoError(t, err) + + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 2) + + for _, ev := range result { + if ev.TargetID == "peer-1" { + assert.Equal(t, "user-a-new", ev.InitiatorID) + } else { + assert.Equal(t, "user-b", ev.InitiatorID) + } + } + }) +} diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index 76cc750b6..27346a604 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -33,15 +33,20 @@ type manager struct { extractor *nbjwt.ClaimsExtractor } -func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager { - // @note if invalid/missing parameters are sent the validator will instantiate - // but it will fail when validating and parsing the token - jwtValidator := nbjwt.NewValidator( - issuer, - allAudiences, - keysLocation, - idpRefreshKeys, - ) +func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager { + var jwtValidator *nbjwt.Validator + if keyFetcher != nil { + jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher) + } else { + // @note if invalid/missing parameters are sent the validator will instantiate + // but it will fail when validating and parsing the token + jwtValidator = nbjwt.NewValidator( + issuer, + allAudiences, + keysLocation, + idpRefreshKeys, + ) + } claimsExtractor := nbjwt.NewClaimsExtractor( nbjwt.WithAudience(audience), diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index b9f091b1e..469737f47 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) user, pat, _, _, err := manager.GetPATInfo(context.Background(), token) if err != nil { @@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) err = manager.MarkPATUsed(context.Background(), "tokenId") if err != nil { @@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { // these tests only assert groups are parsed from token as per account settings token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}) - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) t.Run("JWT groups disabled", func(t *testing.T) { userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) @@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { keyId := "test-key" // note, we can use a nil store because ValidateAndParseToken does not use it in it's flow - manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false) + manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil) customClaim := func(name string) string { return fmt.Sprintf("%s/%s", audience, name) diff --git a/management/server/auth/session.go b/management/server/auth/session.go new file mode 100644 index 000000000..7621a1c10 --- /dev/null +++ b/management/server/auth/session.go @@ -0,0 +1,61 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" +) + +const ( + usedTokenKeyPrefix = "jwt-used:" + usedTokenMarker = "1" +) + +var ( + ErrTokenAlreadyUsed = errors.New("JWT already used") + ErrTokenExpired = errors.New("JWT expired") +) + +type SessionStore struct { + cache *cache.Cache[string] +} + +func NewSessionStore(cacheStore store.StoreInterface) *SessionStore { + return &SessionStore{cache: cache.New[string](cacheStore)} +} + +// RegisterToken records a JWT until its exp time and rejects reuse. +func (s *SessionStore) RegisterToken(ctx context.Context, token string, expiresAt time.Time) error { + ttl := time.Until(expiresAt) + if ttl <= 0 { + return ErrTokenExpired + } + + key := usedTokenKeyPrefix + hashToken(token) + _, err := s.cache.Get(ctx, key) + if err == nil { + return ErrTokenAlreadyUsed + } + + var notFound *store.NotFound + if !errors.As(err, ¬Found) { + return fmt.Errorf("failed to lookup used token entry: %w", err) + } + + if err := s.cache.Set(ctx, key, usedTokenMarker, store.WithExpiration(ttl)); err != nil { + return fmt.Errorf("failed to store used token entry: %w", err) + } + + return nil +} + +func hashToken(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} diff --git a/management/server/auth/session_test.go b/management/server/auth/session_test.go new file mode 100644 index 000000000..3a7d85f4c --- /dev/null +++ b/management/server/auth/session_test.go @@ -0,0 +1,82 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcache "github.com/netbirdio/netbird/management/server/cache" +) + +func newTestSessionStore(t *testing.T) *SessionStore { + t.Helper() + cacheStore, err := nbcache.NewStore(context.Background(), time.Hour, time.Hour, 100) + require.NoError(t, err) + return NewSessionStore(cacheStore) +} + +func TestSessionStore_FirstRegisterSucceeds(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + + require.NoError(t, s.RegisterToken(ctx, "token", time.Now().Add(time.Hour))) +} + +func TestSessionStore_RegisterSameTokenTwiceIsRejected(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + exp := time.Now().Add(time.Hour) + + require.NoError(t, s.RegisterToken(ctx, token, exp)) + + err := s.RegisterToken(ctx, token, exp) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTokenAlreadyUsed) +} + +func TestSessionStore_RegisterDifferentTokensAreIndependent(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + exp := time.Now().Add(time.Hour) + + require.NoError(t, s.RegisterToken(ctx, "tokenA", exp)) + require.NoError(t, s.RegisterToken(ctx, "tokenB", exp)) +} + +func TestSessionStore_RegisterWithPastExpiryIsRejected(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + + err := s.RegisterToken(ctx, token, time.Now().Add(-time.Second)) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTokenExpired) +} + +func TestSessionStore_EntryEvictsAtTTLAndAllowsReRegistration(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + + require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond))) + + err := s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond)) + assert.ErrorIs(t, err, ErrTokenAlreadyUsed) + + time.Sleep(120 * time.Millisecond) + + require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(time.Hour))) +} + +func TestHashToken_StableAndDoesNotLeak(t *testing.T) { + a := hashToken("tokenA") + b := hashToken("tokenB") + assert.Equal(t, a, hashToken("tokenA"), "hash must be deterministic") + assert.NotEqual(t, a, b, "different tokens must hash differently") + assert.Len(t, a, 64, "sha256 hex must be 64 chars") + assert.NotContains(t, a, "tokenA", "raw token must not appear in hash") +} diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 54b0242de..2ca8e8603 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -17,12 +17,24 @@ import ( // RedisStoreEnvVar is the environment variable that determines if a redis store should be used. // The value should follow redis URL format. https://github.com/redis/redis-specifications/blob/master/uri/redis.txt -const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" +const RedisStoreEnvVar = "NB_CACHE_REDIS_ADDRESS" + +// legacyIdPCacheRedisEnvVar is the previous environment variable used for IDP cache. +const legacyIdPCacheRedisEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" + +const ( + // DefaultStoreMaxTimeout is the default max timeout for the shared cache store. + DefaultStoreMaxTimeout = 7 * 24 * time.Hour + // DefaultStoreCleanupInterval is the default cleanup interval for the shared cache store. + DefaultStoreCleanupInterval = 30 * time.Minute + // DefaultStoreMaxConn is the default max connections for the shared cache store. + DefaultStoreMaxConn = 1000 +) // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { - redisAddr := os.Getenv(RedisStoreEnvVar) + redisAddr := GetAddrFromEnv() if redisAddr != "" { return getRedisStore(ctx, redisAddr, maxConn) } @@ -30,6 +42,15 @@ func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, ma return gocache_store.NewGoCache(goc), nil } +// GetAddrFromEnv returns the redis address from the environment variable RedisStoreEnvVar or its legacy counterpart. +func GetAddrFromEnv() string { + addr := os.Getenv(RedisStoreEnvVar) + if addr == "" { + addr = os.Getenv(legacyIdPCacheRedisEnvVar) + } + return addr +} + func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { diff --git a/management/server/dns.go b/management/server/dns.go index baf6debc3..c62fa5185 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -86,7 +86,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/dns_test.go b/management/server/dns_test.go index bd0755d0d..c443223c6 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/permissions" @@ -225,11 +226,17 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createDNSStore(t *testing.T) (store.Store, error) { @@ -451,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -471,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -511,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 30fd493e8..0af3ce2f6 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -130,6 +130,10 @@ func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) { gl.mux.RLock() defer gl.mux.RUnlock() + if gl.db == nil { + return nil, fmt.Errorf("geolocation database is not available") + } + var record Record err := gl.db.Lookup(ip, &record) if err != nil { @@ -173,8 +177,14 @@ func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, er func (gl *geolocationImpl) Stop() error { close(gl.stopCh) - if gl.db != nil { - if err := gl.db.Close(); err != nil { + + gl.mux.Lock() + db := gl.db + gl.db = nil + gl.mux.Unlock() + + if db != nil { + if err := db.Close(); err != nil { return err } } diff --git a/management/server/group.go b/management/server/group.go index 326b167cf..e1d05171e 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { + return nil, err + } return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) } @@ -114,7 +117,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate}) } return nil @@ -182,7 +185,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -250,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate}) } return globalErr @@ -318,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return globalErr @@ -490,7 +493,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -528,7 +531,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -556,7 +559,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -594,7 +597,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/group_test.go b/management/server/group_test.go index fa818e532..5821b90a3 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ad36b9d46..b9ea605d3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -5,9 +5,6 @@ import ( "fmt" "net/http" "net/netip" - "os" - "strconv" - "time" "github.com/gorilla/mux" "github.com/rs/cors" @@ -65,15 +62,10 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const ( - apiPrefix = "/api" - rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" - rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" - rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" -) +const apiPrefix = "/api" // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -94,34 +86,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("failed to add bypass path: %w", err) } - var rateLimitingConfig *middleware.RateLimiterConfig - if os.Getenv(rateLimitingEnabledKey) == "true" { - rpm := 6 - if v := os.Getenv(rateLimitingRPMKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) - } else { - rpm = value - } - } - - burst := 500 - if v := os.Getenv(rateLimitingBurstKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) - } else { - burst = value - } - } - - rateLimitingConfig = &middleware.RateLimiterConfig{ - RequestsPerMinute: float64(rpm), - Burst: burst, - CleanupInterval: 6 * time.Hour, - LimiterTTL: 24 * time.Hour, - } + if rateLimiter == nil { + log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled") + rateLimiter = middleware.NewAPIRateLimiter(nil) + rateLimiter.SetEnabled(false) } authMiddleware := middleware.NewAuthMiddleware( @@ -129,7 +97,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, - rateLimitingConfig, + rateLimiter, appMetrics.GetMeter(), ) @@ -171,7 +139,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks zonesManager.RegisterEndpoints(router, zManager) recordsManager.RegisterEndpoints(router, rManager) idp.AddEndpoints(accountManager, router) - instance.AddEndpoints(instanceManager, router) + instance.AddEndpoints(instanceManager, accountManager, router) instance.AddVersionEndpoint(instanceManager, router) if serviceManager != nil && reverseProxyDomainManager != nil { reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 56ccc9d0b..f8d161a87 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { groupName := r.URL.Query().Get("name") if groupName != "" { // Get single group by name - group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { return } - allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 458a15c11..c7b4cbcdd 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return groups, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index cd9fae6b8..e98ce4d7c 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" nbinstance "github.com/netbirdio/netbird/management/server/instance" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -15,13 +16,15 @@ import ( // handler handles the instance setup HTTP endpoints type handler struct { instanceManager nbinstance.Manager + setupManager *nbinstance.SetupService } // AddEndpoints registers the instance setup endpoints. // These endpoints bypass authentication for initial setup. -func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) { +func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, router *mux.Router) { h := &handler{ instanceManager: instanceManager, + setupManager: nbinstance.NewSetupService(instanceManager, accountManager), } router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS") @@ -55,24 +58,36 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { // setup creates the initial admin user for the instance. // This endpoint is unauthenticated but only works when setup is required. func (h *handler) setup(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var req api.SetupRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w) return } - userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name) + result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{ + CreatePAT: req.CreatePat != nil && *req.CreatePat, + PATExpireInDays: req.PatExpireIn, + }) if err != nil { - util.WriteError(r.Context(), err, w) + util.WriteError(ctx, err, w) return } - log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email) + log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email) - util.WriteJSONObject(r.Context(), w, api.SetupResponse{ - UserId: userData.ID, - Email: userData.Email, - }) + resp := api.SetupResponse{ + UserId: result.User.ID, + Email: result.User.Email, + } + + if result.PATPlainToken != "" { + resp.PersonalAccessToken = &result.PATPlainToken + } + + w.Header().Set("Cache-Control", "no-store") + util.WriteJSONObject(ctx, w, resp) } // getVersionInfo returns version information for NetBird components. diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index 470079c85..711e01964 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -10,12 +10,18 @@ import ( "net/mail" "testing" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/idp" nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/management/server/mock_server" + nbstore "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -25,6 +31,7 @@ type mockInstanceManager struct { isSetupRequired bool isSetupRequiredFn func(ctx context.Context) (bool, error) createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + rollbackSetupFn func(ctx context.Context, userID string) error getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error) } @@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo }, nil } +func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error { + if m.rollbackSetupFn != nil { + return m.rollbackSetupFn(ctx, userID) + } + return nil +} + func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) { if m.getVersionInfoFn != nil { return m.getVersionInfoFn(ctx) @@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V var _ nbinstance.Manager = (*mockInstanceManager)(nil) func setupTestRouter(manager nbinstance.Manager) *mux.Router { + return setupTestRouterWithPAT(manager, nil) +} + +func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager) *mux.Router { router := mux.NewRouter() - AddEndpoints(manager, router) + AddEndpoints(manager, accountManager, router) return router } @@ -161,6 +179,7 @@ func TestSetup_Success(t *testing.T) { router.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) var response api.SetupResponse err := json.NewDecoder(rec.Body).Decode(&response) @@ -293,6 +312,239 @@ func TestSetup_ManagerError(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) } +func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "false") + + manager := &mockInstanceManager{isSetupRequired: true} + // NB_SETUP_PAT_ENABLED=false: request fields must be silently ignored + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Nil(t, response.PersonalAccessToken) +} + +func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Nil(t, response.PersonalAccessToken) +} + +func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + createCalled := false + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "u1", Email: email, Name: name}, nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "u1", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "u1", initiator) + assert.Equal(t, "u1", target) + assert.Equal(t, "setup-token", name) + assert.Equal(t, 1, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + } + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) + assert.True(t, createCalled) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + require.NotNil(t, response.PersonalAccessToken) + assert.Equal(t, "nbp_plain", *response.PersonalAccessToken) +} + +func TestSetup_PAT_ExpireOutOfRange(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) +} + +func TestSetup_PAT_Success(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + } + + gotAccountArgs := struct { + userID string + email string + }{} + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + gotAccountArgs.userID = userAuth.UserId + gotAccountArgs.email = userAuth.Email + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiator) + assert.Equal(t, "owner-id", target) + assert.Equal(t, "setup-token", name) + assert.Equal(t, 30, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Equal(t, "owner-id", response.UserId) + require.NotNil(t, response.PersonalAccessToken) + assert.Equal(t, "nbp_plain", *response.PersonalAccessToken) + assert.Equal(t, "owner-id", gotAccountArgs.userID) + assert.Equal(t, "admin@example.com", gotAccountArgs.email) +} + +func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("", status.NewAccountNotFoundError("owner-id")) + + rolledBackFor := "" + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, userID string) error { + rolledBackFor = userID + return nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "", errors.New("db down") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id") +} + +func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rolledBackFor := "" + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, userID string) error { + rolledBackFor = userID + return nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, status.Errorf(status.Internal, "token store unavailable") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called when CreatePAT fails") +} + func TestGetVersionInfo_Success(t *testing.T) { manager := &mockInstanceManager{} router := mux.NewRouter() diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index c311a29fe..ce9efb78d 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -105,6 +105,12 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { router.NetworkID = networkID router.AccountID = accountID router.Enabled = true + + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) @@ -157,6 +163,11 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { router.ID = mux.Vars(r)["routerId"] router.AccountID = accountID + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 6b9a69f04..bf6937a49 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) + netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 922bf4352..c99acab63 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -22,6 +22,7 @@ import ( nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -191,11 +192,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { oidcServer := newFakeOIDCServer() - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) usersManager := users.NewManager(testStore) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 63be672e6..6d075d9c2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel/metric" "github.com/netbirdio/management-integrations/integrations" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" @@ -42,14 +43,9 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, - rateLimiterConfig *RateLimiterConfig, + rateLimiter *APIRateLimiter, meter metric.Meter, ) *AuthMiddleware { - var rateLimiter *APIRateLimiter - if rateLimiterConfig != nil { - rateLimiter = NewAPIRateLimiter(rateLimiterConfig) - } - var patUsageTracker *PATUsageTracker if meter != nil { var err error @@ -87,17 +83,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, authHeader) - if err != nil { + if err := m.checkJWTFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) case "token": - request, err := m.checkPATFromRequest(r, authHeader) - if err != nil { + if err := m.checkPATFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) // Check if it's a status error, otherwise default to Unauthorized if _, ok := status.FromError(err); !ok { @@ -106,7 +99,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { util.WriteError(r.Context(), err, w) return } - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) default: util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return @@ -115,19 +108,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } ctx := r.Context() userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) if err != nil { - return r, err + return err } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { @@ -143,7 +136,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := m.ensureAccount(ctx, userAuth) if err != nil { - return r, err + return err } if userAuth.AccountId != accountId { @@ -153,7 +146,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) if err != nil { - return r, err + return err } err = m.syncUserJWTGroups(ctx, userAuth) @@ -164,41 +157,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] _, err = m.getUserFromUserAuth(ctx, userAuth) if err != nil { log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) - return r, err + return err } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } if m.patUsageTracker != nil { m.patUsageTracker.IncrementUsage(token) } - if m.rateLimiter != nil && !isTerraformRequest(r) { - if !m.rateLimiter.Allow(token) { - return r, status.Errorf(status.TooManyRequests, "too many requests") - } + if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) { + return status.Errorf(status.TooManyRequests, "too many requests") } ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { - return r, fmt.Errorf("invalid Token: %w", err) + return fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.GetExpirationDate()) { - return r, fmt.Errorf("token expired") + return fmt.Errorf("token expired") } err = m.authManager.MarkPATUsed(ctx, pat.ID) if err != nil { - return r, err + return err } userAuth := auth.UserAuth{ @@ -216,7 +209,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] } } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } func isTerraformRequest(r *http.Request) bool { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index f397c63a4..8f736fbfd 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) @@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go index 936b34319..bfd44afee 100644 --- a/management/server/http/middleware/rate_limiter.go +++ b/management/server/http/middleware/rate_limiter.go @@ -4,14 +4,27 @@ import ( "context" "net" "net/http" + "os" + "strconv" "sync" + "sync/atomic" "time" + log "github.com/sirupsen/logrus" "golang.org/x/time/rate" "github.com/netbirdio/netbird/shared/management/http/util" ) +const ( + RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED" + RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST" + RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM" + + defaultAPIRPM = 6 + defaultAPIBurst = 500 +) + // RateLimiterConfig holds configuration for the API rate limiter type RateLimiterConfig struct { // RequestsPerMinute defines the rate at which tokens are replenished @@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig { } } +func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) { + rpm := defaultAPIRPM + if v := os.Getenv(RateLimitingRPMEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm) + } else { + rpm = value + } + } + if rpm <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM) + rpm = defaultAPIRPM + } + + burst := defaultAPIBurst + if v := os.Getenv(RateLimitingBurstEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst) + } else { + burst = value + } + } + if burst <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst) + burst = defaultAPIBurst + } + + return &RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + }, os.Getenv(RateLimitingEnabledEnv) == "true" +} + // limiterEntry holds a rate limiter and its last access time type limiterEntry struct { limiter *rate.Limiter @@ -46,6 +96,7 @@ type APIRateLimiter struct { limiters map[string]*limiterEntry mu sync.RWMutex stopChan chan struct{} + enabled atomic.Bool } // NewAPIRateLimiter creates a new API rate limiter with the given configuration @@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { limiters: make(map[string]*limiterEntry), stopChan: make(chan struct{}), } + rl.enabled.Store(true) go rl.cleanupLoop() return rl } +func (rl *APIRateLimiter) SetEnabled(enabled bool) { + rl.enabled.Store(enabled) +} + +func (rl *APIRateLimiter) Enabled() bool { + return rl.enabled.Load() +} + +func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) { + if config == nil { + return + } + if config.RequestsPerMinute <= 0 || config.Burst <= 0 { + log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst) + return + } + + newRPS := rate.Limit(config.RequestsPerMinute / 60.0) + newBurst := config.Burst + + rl.mu.Lock() + rl.config.RequestsPerMinute = config.RequestsPerMinute + rl.config.Burst = newBurst + snapshot := make([]*rate.Limiter, 0, len(rl.limiters)) + for _, entry := range rl.limiters { + snapshot = append(snapshot, entry.limiter) + } + rl.mu.Unlock() + + for _, l := range snapshot { + l.SetLimit(newRPS) + l.SetBurst(newBurst) + } +} + // Allow checks if a request for the given key (token) is allowed func (rl *APIRateLimiter) Allow(key string) bool { + if !rl.enabled.Load() { + return true + } limiter := rl.getLimiter(key) return limiter.Allow() } @@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool { // Wait blocks until the rate limiter allows another request for the given key // Returns an error if the context is canceled func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + if !rl.enabled.Load() { + return nil + } limiter := rl.getLimiter(key) return limiter.Wait(ctx) } @@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) { // Returns 429 Too Many Requests if the rate limit is exceeded. func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !rl.enabled.Load() { + next.ServeHTTP(w, r) + return + } clientIP := getClientIP(r) if !rl.Allow(clientIP) { util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) diff --git a/management/server/http/middleware/rate_limiter_test.go b/management/server/http/middleware/rate_limiter_test.go index 68f804e57..4b97d1874 100644 --- a/management/server/http/middleware/rate_limiter_test.go +++ b/management/server/http/middleware/rate_limiter_test.go @@ -1,8 +1,10 @@ package middleware import ( + "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) { // Should be allowed again assert.True(t, rl.Allow("test-key")) } + +func TestAPIRateLimiter_SetEnabled(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("key")) + assert.False(t, rl.Allow("key"), "burst exhausted while enabled") + + rl.SetEnabled(false) + assert.False(t, rl.Enabled()) + for i := 0; i < 5; i++ { + assert.True(t, rl.Allow("key"), "disabled limiter must always allow") + } + + rl.SetEnabled(true) + assert.True(t, rl.Enabled()) + assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state") +} + +func TestAPIRateLimiter_UpdateConfig(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k1")) + assert.True(t, rl.Allow("k1")) + assert.False(t, rl.Allow("k1"), "burst=2 exhausted") + + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + + // New burst applies to existing keys in place; bucket refills up to new burst over time, + // but importantly newly-added keys use the updated config immediately. + assert.True(t, rl.Allow("k2")) + for i := 0; i < 9; i++ { + assert.True(t, rl.Allow("k2")) + } + assert.False(t, rl.Allow("k2"), "new burst=10 exhausted") +} + +func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + rl.UpdateConfig(nil) // must not panic or zero the config + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) +} + +func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) + + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + + rl.Reset("k") + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored") +} + +func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 600, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + var wg sync.WaitGroup + stop := make(chan struct{}) + + for i := 0; i < 8; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + key := fmt.Sprintf("k%d", id) + for { + select { + case <-stop: + return + default: + rl.Allow(key) + } + } + }(i) + } + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 200; i++ { + select { + case <-stop: + return + default: + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: float64(30 + (i % 90)), + Burst: 1 + (i % 20), + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + rl.SetEnabled(i%2 == 0) + } + } + }() + + time.Sleep(100 * time.Millisecond) + close(stop) + wg.Wait() +} + +func TestRateLimiterConfigFromEnv(t *testing.T) { + t.Setenv(RateLimitingEnabledEnv, "true") + t.Setenv(RateLimitingRPMEnv, "42") + t.Setenv(RateLimitingBurstEnv, "7") + + cfg, enabled := RateLimiterConfigFromEnv() + assert.True(t, enabled) + assert.Equal(t, float64(42), cfg.RequestsPerMinute) + assert.Equal(t, 7, cfg.Burst) + + t.Setenv(RateLimitingEnabledEnv, "false") + _, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + + t.Setenv(RateLimitingEnabledEnv, "") + t.Setenv(RateLimitingRPMEnv, "") + t.Setenv(RateLimitingBurstEnv, "") + cfg, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute) + assert.Equal(t, defaultAPIBurst, cfg.Burst) + + t.Setenv(RateLimitingRPMEnv, "0") + t.Setenv(RateLimitingBurstEnv, "-5") + cfg, _ = RateLimiterConfigFromEnv() + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default") + assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default") +} diff --git a/management/server/http/testing/integration/networks_handler_integration_test.go b/management/server/http/testing/integration/networks_handler_integration_test.go index 4cb6b268b..54f204a8f 100644 --- a/management/server/http/testing/integration/networks_handler_integration_test.go +++ b/management/server/http/testing/integration/networks_handler_integration_test.go @@ -1170,13 +1170,17 @@ func Test_NetworkRouters_Create(t *testing.T) { Metric: 100, Enabled: true, }, - expectedStatus: http.StatusOK, - verifyResponse: func(t *testing.T, router *api.NetworkRouter) { - t.Helper() - assert.NotEmpty(t, router.Id) - assert.Equal(t, peerID, *router.Peer) - assert.Equal(t, 1, len(*router.PeerGroups)) + expectedStatus: http.StatusBadRequest, + }, + { + name: "Create router without peer and peer_groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, }, + expectedStatus: http.StatusBadRequest, }, { name: "Create router in non-existing network", @@ -1341,13 +1345,18 @@ func Test_NetworkRouters_Update(t *testing.T) { Metric: 100, Enabled: true, }, - expectedStatus: http.StatusOK, - verifyResponse: func(t *testing.T, router *api.NetworkRouter) { - t.Helper() - assert.Equal(t, "testRouterId", router.Id) - assert.Equal(t, peerID, *router.Peer) - assert.Equal(t, 1, len(*router.PeerGroups)) + expectedStatus: http.StatusBadRequest, + }, + { + name: "Update router without peer and peer_groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, }, + expectedStatus: http.StatusBadRequest, }, } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 55095bbb7..1a8b83c7e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -35,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" serverauth "github.com/netbirdio/netbird/management/server/auth" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -87,22 +88,22 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee jobManager := job.NewJobManager(nil, store, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) - am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) if err != nil { t.Fatalf("Failed to create manager: %v", err) } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create proxy token store: %v", err) - } - pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create PKCE verifier store: %v", err) - } + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if err != nil { @@ -114,13 +115,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } - domainManager.SetClusterCapabilities(serviceProxyController) - serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil) authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, @@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -217,22 +217,22 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin jobManager := job.NewJobManager(nil, store, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) - am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) if err != nil { t.Fatalf("Failed to create manager: %v", err) } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create proxy token store: %v", err) - } - pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create PKCE verifier store: %v", err) - } + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if err != nil { @@ -244,13 +244,12 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } - domainManager.SetClusterCapabilities(serviceProxyController) - serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil) authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, @@ -265,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/identity_provider.go b/management/server/identity_provider.go index 8fd96c238..f965f36b8 100644 --- a/management/server/identity_provider.go +++ b/management/server/identity_provider.go @@ -274,7 +274,7 @@ func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.C } // generateIdentityProviderID generates a unique ID for an identity provider. -// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft), +// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft, adfs), // the ID is prefixed with the type name. Generic OIDC providers get no prefix. func generateIdentityProviderID(idpType types.IdentityProviderType) string { id := xid.New().String() @@ -296,6 +296,8 @@ func generateIdentityProviderID(idpType types.IdentityProviderType) string { return "authentik-" + id case types.IdentityProviderTypeKeycloak: return "keycloak-" + id + case types.IdentityProviderTypeADFS: + return "adfs-" + id default: // Generic OIDC - no prefix return id diff --git a/management/server/identity_provider_test.go b/management/server/identity_provider_test.go index 9fce6b9c0..d51254c55 100644 --- a/management/server/identity_provider_test.go +++ b/management/server/identity_provider_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "path/filepath" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -19,6 +20,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -83,10 +85,15 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update permissionsManager := permissions.NewManager(testStore) peersManager := peers.NewManager(testStore, permissionsManager) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, testStore) networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index 2cc7b9743..48d3221cc 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/management/server/telemetry" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) const ( @@ -48,6 +49,8 @@ type EmbeddedIdPConfig struct { // Existing local users are preserved and will be able to login again if re-enabled. // Cannot be enabled if no external identity provider connectors are configured. LocalAuthDisabled bool + // StaticConnectors are additional connectors to seed during initialization + StaticConnectors []dex.Connector } // EmbeddedStorageConfig holds storage configuration for the embedded IdP. @@ -157,6 +160,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { RedirectURIs: cliRedirectURIs, }, }, + StaticConnectors: c.StaticConnectors, } // Add owner user if provided @@ -193,6 +197,9 @@ type OAuthConfigProvider interface { // Management server has embedded Dex and can validate tokens via localhost, // avoiding external network calls and DNS resolution issues during startup. GetLocalKeysLocation() string + // GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage, + // or nil if direct key fetching is not supported (falls back to HTTP). + GetKeyFetcher() nbjwt.KeyFetcher GetClientIDs() []string GetUserIDClaim() string GetTokenEndpoint() string @@ -593,6 +600,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string { return m.config.CLIRedirectURIs } +// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage. +func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher { + return m.provider.GetJWKS +} + // GetKeysLocation returns the JWKS endpoint URL for token validation. func (m *EmbeddedIdPManager) GetKeysLocation() string { return m.provider.GetKeysLocation() diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 48e4f3000..dadbfd83e 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -66,14 +66,14 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient } // Create a new Admin SDK Directory service client - adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey) + credentialsOption, err := getGoogleCredentialsOption(ctx, config.ServiceAccountKey) if err != nil { return nil, err } service, err := admin.NewService(context.Background(), option.WithScopes(admin.AdminDirectoryUserReadonlyScope), - option.WithCredentials(adminCredentials), + credentialsOption, ) if err != nil { return nil, err @@ -218,39 +218,32 @@ func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) e return nil } -// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. -// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. -// If that fails, it falls back to using the default Google credentials path. -// It returns the retrieved credentials or an error if unsuccessful. -func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) { +// getGoogleCredentialsOption returns the google.golang.org/api option carrying +// Google credentials derived from the provided serviceAccountKey. +// It decodes the base64-encoded serviceAccountKey and uses it as the credentials JSON. +// If the key is empty, it falls back to the default Google credentials path. +func getGoogleCredentialsOption(ctx context.Context, serviceAccountKey string) (option.ClientOption, error) { log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key") decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) if err != nil { return nil, fmt.Errorf("failed to decode service account key: %w", err) } - creds, err := google.CredentialsFromJSON( - context.Background(), - decodeKey, - admin.AdminDirectoryUserReadonlyScope, - ) - if err == nil { - // No need to fallback to the default Google credentials path - return creds, nil + if len(decodeKey) > 0 { + return option.WithAuthCredentialsJSON(option.ServiceAccount, decodeKey), nil } - log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) - log.WithContext(ctx).Debug("falling back to default google credentials location") + log.WithContext(ctx).Debug("no service account key provided, falling back to default google credentials location") - creds, err = google.FindDefaultCredentials( - context.Background(), + creds, err := google.FindDefaultCredentials( + ctx, admin.AdminDirectoryUserReadonlyScope, ) if err != nil { return nil, err } - return creds, nil + return option.WithCredentials(creds), nil } // parseGoogleWorkspaceUser parse google user to UserData. diff --git a/management/server/idp/migration/migration.go b/management/server/idp/migration/migration.go new file mode 100644 index 000000000..01cadb86d --- /dev/null +++ b/management/server/idp/migration/migration.go @@ -0,0 +1,235 @@ +// Package migration provides utility functions for migrating from the external IdP solution in pre v0.62.0 +// to the new embedded IdP manager (Dex based), which is the default in v0.62.0 and later. +// It includes functions to seed connectors and migrate existing users to use these connectors. +package migration + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/types" +) + +// Server is the dependency interface that migration functions use to access +// the main data store and the activity event store. +type Server interface { + Store() Store + EventStore() EventStore // may return nil +} + +const idpSeedInfoKey = "IDP_SEED_INFO" +const dryRunEnvKey = "NB_IDP_MIGRATION_DRY_RUN" + +func isDryRun() bool { + return os.Getenv(dryRunEnvKey) == "true" +} + +var ErrNoSeedInfo = errors.New("no seed info found in environment") + +// SeedConnectorFromEnv reads the IDP_SEED_INFO env var, base64-decodes it, +// and JSON-unmarshals it into a dex.Connector. Returns nil if not set. +func SeedConnectorFromEnv() (*dex.Connector, error) { + val, ok := os.LookupEnv(idpSeedInfoKey) + if !ok || val == "" { + return nil, ErrNoSeedInfo + } + + decoded, err := base64.StdEncoding.DecodeString(val) + if err != nil { + return nil, fmt.Errorf("base64 decode: %w", err) + } + + var conn dex.Connector + if err := json.Unmarshal(decoded, &conn); err != nil { + return nil, fmt.Errorf("json unmarshal: %w", err) + } + + return &conn, nil +} + +// MigrateUsersToStaticConnectors re-keys every user ID in the main store (and +// the activity store, if present) so that it encodes the given connector ID, +// skipping users that have already been migrated. Set NB_IDP_MIGRATION_DRY_RUN=true +// to log what would happen without writing any changes. +func MigrateUsersToStaticConnectors(s Server, conn *dex.Connector) error { + ctx := context.Background() + + if isDryRun() { + log.Info("[DRY RUN] migration dry-run mode enabled, no changes will be written") + } + + users, err := s.Store().ListUsers(ctx) + if err != nil { + return fmt.Errorf("failed to list users: %w", err) + } + + // Reconciliation pass: fix activity store for users already migrated in main DB + // but whose activity references may still use old IDs (from a previous partial failure). + if s.EventStore() != nil && !isDryRun() { + if err := reconcileActivityStore(ctx, s.EventStore(), users); err != nil { + return err + } + } + + var migratedCount, skippedCount int + + for _, user := range users { + _, _, decErr := dex.DecodeDexUserID(user.Id) + if decErr == nil { + skippedCount++ + continue + } + + newUserID := dex.EncodeDexUserID(user.Id, conn.ID) + + if isDryRun() { + log.Infof("[DRY RUN] would migrate user %s -> %s (account: %s)", user.Id, newUserID, user.AccountID) + migratedCount++ + continue + } + + if err := migrateUser(ctx, s, user.Id, user.AccountID, newUserID); err != nil { + return err + } + + migratedCount++ + } + + if isDryRun() { + log.Infof("[DRY RUN] migration summary: %d users would be migrated, %d already migrated", migratedCount, skippedCount) + } else { + log.Infof("migration complete: %d users migrated, %d already migrated", migratedCount, skippedCount) + } + + return nil +} + +// reconcileActivityStore updates activity store references for users already migrated +// in the main DB whose activity entries may still use old IDs from a previous partial failure. +func reconcileActivityStore(ctx context.Context, eventStore EventStore, users []*types.User) error { + for _, user := range users { + originalID, _, err := dex.DecodeDexUserID(user.Id) + if err != nil { + // skip users that aren't migrated, they will be handled in the main migration loop + continue + } + if err := eventStore.UpdateUserID(ctx, originalID, user.Id); err != nil { + return fmt.Errorf("reconcile activity store for user %s: %w", user.Id, err) + } + } + return nil +} + +// migrateUser updates a single user's ID in both the main store and the activity store. +func migrateUser(ctx context.Context, s Server, oldID, accountID, newID string) error { + if err := s.Store().UpdateUserID(ctx, accountID, oldID, newID); err != nil { + return fmt.Errorf("failed to update user ID for user %s: %w", oldID, err) + } + + if s.EventStore() == nil { + return nil + } + + if err := s.EventStore().UpdateUserID(ctx, oldID, newID); err != nil { + return fmt.Errorf("failed to update activity store user ID for user %s: %w", oldID, err) + } + + return nil +} + +// PopulateUserInfo fetches user email and name from the external IDP and updates +// the store for users that are missing this information. +func PopulateUserInfo(s Server, idpManager idp.Manager, dryRun bool) error { + ctx := context.Background() + + users, err := s.Store().ListUsers(ctx) + if err != nil { + return fmt.Errorf("failed to list users: %w", err) + } + + // Build a map of IDP user ID -> UserData from the external IDP + allAccounts, err := idpManager.GetAllAccounts(ctx) + if err != nil { + return fmt.Errorf("failed to fetch accounts from IDP: %w", err) + } + + idpUsers := make(map[string]*idp.UserData) + for _, accountUsers := range allAccounts { + for _, userData := range accountUsers { + idpUsers[userData.ID] = userData + } + } + + log.Infof("fetched %d users from IDP", len(idpUsers)) + + var updatedCount, skippedCount, notFoundCount int + + for _, user := range users { + if user.IsServiceUser { + skippedCount++ + continue + } + + if user.Email != "" && user.Name != "" { + skippedCount++ + continue + } + + // The user ID in the store may be the original IDP ID or a Dex-encoded ID. + // Try to decode the Dex format first to get the original IDP ID. + lookupID := user.Id + if originalID, _, decErr := dex.DecodeDexUserID(user.Id); decErr == nil { + lookupID = originalID + } + + idpUser, found := idpUsers[lookupID] + if !found { + notFoundCount++ + log.Debugf("user %s (lookup: %s) not found in IDP, skipping", user.Id, lookupID) + continue + } + + email := user.Email + name := user.Name + if email == "" && idpUser.Email != "" { + email = idpUser.Email + } + if name == "" && idpUser.Name != "" { + name = idpUser.Name + } + + if email == user.Email && name == user.Name { + skippedCount++ + continue + } + + if dryRun { + log.Infof("[DRY RUN] would update user %s: email=%q, name=%q", user.Id, email, name) + updatedCount++ + continue + } + + if err := s.Store().UpdateUserInfo(ctx, user.Id, email, name); err != nil { + return fmt.Errorf("failed to update user info for %s: %w", user.Id, err) + } + + log.Infof("updated user %s: email=%q, name=%q", user.Id, email, name) + updatedCount++ + } + + if dryRun { + log.Infof("[DRY RUN] user info summary: %d would be updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount) + } else { + log.Infof("user info population complete: %d updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount) + } + + return nil +} diff --git a/management/server/idp/migration/migration_test.go b/management/server/idp/migration/migration_test.go new file mode 100644 index 000000000..2ff71347e --- /dev/null +++ b/management/server/idp/migration/migration_test.go @@ -0,0 +1,828 @@ +package migration + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/types" +) + +// testStore is a hand-written mock for MigrationStore. +type testStore struct { + listUsersFunc func(ctx context.Context) ([]*types.User, error) + updateUserIDFunc func(ctx context.Context, accountID, oldUserID, newUserID string) error + updateUserInfoFunc func(ctx context.Context, userID, email, name string) error + checkSchemaFunc func(checks []SchemaCheck) []SchemaError + updateCalls []updateUserIDCall + updateInfoCalls []updateUserInfoCall +} + +type updateUserIDCall struct { + AccountID string + OldUserID string + NewUserID string +} + +type updateUserInfoCall struct { + UserID string + Email string + Name string +} + +func (s *testStore) ListUsers(ctx context.Context) ([]*types.User, error) { + return s.listUsersFunc(ctx) +} + +func (s *testStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error { + s.updateCalls = append(s.updateCalls, updateUserIDCall{accountID, oldUserID, newUserID}) + return s.updateUserIDFunc(ctx, accountID, oldUserID, newUserID) +} + +func (s *testStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error { + s.updateInfoCalls = append(s.updateInfoCalls, updateUserInfoCall{userID, email, name}) + if s.updateUserInfoFunc != nil { + return s.updateUserInfoFunc(ctx, userID, email, name) + } + return nil +} + +func (s *testStore) CheckSchema(checks []SchemaCheck) []SchemaError { + if s.checkSchemaFunc != nil { + return s.checkSchemaFunc(checks) + } + return nil +} + +type testServer struct { + store Store + eventStore EventStore +} + +func (s *testServer) Store() Store { return s.store } +func (s *testServer) EventStore() EventStore { return s.eventStore } + +func TestSeedConnectorFromEnv(t *testing.T) { + t.Run("returns ErrNoSeedInfo when env var is not set", func(t *testing.T) { + os.Unsetenv(idpSeedInfoKey) + + conn, err := SeedConnectorFromEnv() + assert.ErrorIs(t, err, ErrNoSeedInfo) + assert.Nil(t, conn) + }) + + t.Run("returns ErrNoSeedInfo when env var is empty", func(t *testing.T) { + t.Setenv(idpSeedInfoKey, "") + + conn, err := SeedConnectorFromEnv() + assert.ErrorIs(t, err, ErrNoSeedInfo) + assert.Nil(t, conn) + }) + + t.Run("returns error on invalid base64", func(t *testing.T) { + t.Setenv(idpSeedInfoKey, "not-valid-base64!!!") + + conn, err := SeedConnectorFromEnv() + assert.NotErrorIs(t, err, ErrNoSeedInfo) + assert.Error(t, err) + assert.Nil(t, conn) + assert.Contains(t, err.Error(), "base64 decode") + }) + + t.Run("returns error on invalid JSON", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("not json")) + t.Setenv(idpSeedInfoKey, encoded) + + conn, err := SeedConnectorFromEnv() + assert.NotErrorIs(t, err, ErrNoSeedInfo) + assert.Error(t, err) + assert.Nil(t, conn) + assert.Contains(t, err.Error(), "json unmarshal") + }) + + t.Run("successfully decodes valid connector", func(t *testing.T) { + expected := dex.Connector{ + Type: "oidc", + Name: "Test Provider", + ID: "test-provider", + Config: map[string]any{ + "issuer": "https://example.com", + "clientID": "my-client-id", + "clientSecret": "my-secret", + }, + } + + data, err := json.Marshal(expected) + require.NoError(t, err) + + encoded := base64.StdEncoding.EncodeToString(data) + t.Setenv(idpSeedInfoKey, encoded) + + conn, err := SeedConnectorFromEnv() + assert.NoError(t, err) + require.NotNil(t, conn) + assert.Equal(t, expected.Type, conn.Type) + assert.Equal(t, expected.Name, conn.Name) + assert.Equal(t, expected.ID, conn.ID) + assert.Equal(t, expected.Config["issuer"], conn.Config["issuer"]) + }) +} + +func TestMigrateUsersToStaticConnectors(t *testing.T) { + connector := &dex.Connector{ + Type: "oidc", + Name: "Test Provider", + ID: "test-connector", + } + + t.Run("succeeds with no users", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + }) + + t.Run("returns error when ListUsers fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return nil, fmt.Errorf("db error") + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to list users") + }) + + t.Run("migrates single user with correct encoded ID", func(t *testing.T) { + user := &types.User{Id: "user-1", AccountID: "account-1"} + expectedNewID := dex.EncodeDexUserID("user-1", "test-connector") + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{user}, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + require.Len(t, ms.updateCalls, 1) + assert.Equal(t, "account-1", ms.updateCalls[0].AccountID) + assert.Equal(t, "user-1", ms.updateCalls[0].OldUserID) + assert.Equal(t, expectedNewID, ms.updateCalls[0].NewUserID) + }) + + t.Run("migrates multiple users", func(t *testing.T) { + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + {Id: "user-3", AccountID: "account-2"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 3) + }) + + t.Run("returns error when UpdateUserID fails", func(t *testing.T) { + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + callCount := 0 + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + callCount++ + if callCount == 2 { + return fmt.Errorf("update failed") + } + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to update user ID for user user-2") + }) + + t.Run("stops on first UpdateUserID error", func(t *testing.T) { + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return fmt.Errorf("update failed") + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.Error(t, err) + assert.Len(t, ms.updateCalls, 1) // stopped after first error + }) + + t.Run("skips already migrated users", func(t *testing.T) { + alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector") + users := []*types.User{ + {Id: alreadyMigratedID, AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 0) + }) + + t.Run("migrates only non-migrated users in mixed state", func(t *testing.T) { + alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector") + users := []*types.User{ + {Id: alreadyMigratedID, AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + {Id: "user-3", AccountID: "account-2"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + // Only user-2 and user-3 should be migrated + assert.Len(t, ms.updateCalls, 2) + assert.Equal(t, "user-2", ms.updateCalls[0].OldUserID) + assert.Equal(t, "user-3", ms.updateCalls[1].OldUserID) + }) + + t.Run("dry run does not call UpdateUserID", func(t *testing.T) { + t.Setenv(dryRunEnvKey, "true") + + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + t.Fatal("UpdateUserID should not be called in dry-run mode") + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 0) + }) + + t.Run("dry run skips already migrated users", func(t *testing.T) { + t.Setenv(dryRunEnvKey, "true") + + alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector") + users := []*types.User{ + {Id: alreadyMigratedID, AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + t.Fatal("UpdateUserID should not be called in dry-run mode") + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + }) + + t.Run("dry run disabled by default", func(t *testing.T) { + user := &types.User{Id: "user-1", AccountID: "account-1"} + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{user}, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 1) // proves it's not in dry-run + }) +} + +func TestPopulateUserInfo(t *testing.T) { + noopUpdateID := func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil } + + t.Run("succeeds with no users", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{}, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("returns error when ListUsers fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return nil, fmt.Errorf("db error") + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{} + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to list users") + }) + + t.Run("returns error when GetAllAccounts fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{{Id: "user-1", AccountID: "acc-1"}}, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return nil, fmt.Errorf("idp error") + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to fetch accounts from IDP") + }) + + t.Run("updates user with missing email and name", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "user1@example.com", Name: "User One"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID) + assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "User One", ms.updateInfoCalls[0].Name) + }) + + t.Run("updates only missing email when name exists", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: "Existing Name"}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "user1@example.com", Name: "IDP Name"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "Existing Name", ms.updateInfoCalls[0].Name) + }) + + t.Run("updates only missing name when email exists", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "existing@example.com", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "idp@example.com", Name: "IDP Name"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, "existing@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "IDP Name", ms.updateInfoCalls[0].Name) + }) + + t.Run("skips users that already have both email and name", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "user1@example.com", Name: "User One"}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "different@example.com", Name: "Different Name"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("skips service users", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "svc-1", AccountID: "acc-1", Email: "", Name: "", IsServiceUser: true}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "svc-1", Email: "svc@example.com", Name: "Service"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("skips users not found in IDP", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "different-user", Email: "other@example.com", Name: "Other"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("looks up dex-encoded user IDs by original ID", func(t *testing.T) { + dexEncodedID := dex.EncodeDexUserID("original-idp-id", "my-connector") + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: dexEncodedID, AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "original-idp-id", Email: "user@example.com", Name: "User"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, dexEncodedID, ms.updateInfoCalls[0].UserID) + assert.Equal(t, "user@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "User", ms.updateInfoCalls[0].Name) + }) + + t.Run("handles multiple users across multiple accounts", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + {Id: "user-2", AccountID: "acc-1", Email: "already@set.com", Name: "Already Set"}, + {Id: "user-3", AccountID: "acc-2", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "u1@example.com", Name: "User 1"}, + {ID: "user-2", Email: "u2@example.com", Name: "User 2"}, + }, + "acc-2": { + {ID: "user-3", Email: "u3@example.com", Name: "User 3"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 2) + assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID) + assert.Equal(t, "u1@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "user-3", ms.updateInfoCalls[1].UserID) + assert.Equal(t, "u3@example.com", ms.updateInfoCalls[1].Email) + }) + + t.Run("returns error when UpdateUserInfo fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error { + return fmt.Errorf("db write error") + }, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "u1@example.com", Name: "User 1"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to update user info for user-1") + }) + + t.Run("stops on first UpdateUserInfo error", func(t *testing.T) { + callCount := 0 + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + {Id: "user-2", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error { + callCount++ + return fmt.Errorf("db write error") + }, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "u1@example.com", Name: "U1"}, + {ID: "user-2", Email: "u2@example.com", Name: "U2"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Equal(t, 1, callCount) + }) + + t.Run("dry run does not call UpdateUserInfo", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + {Id: "user-2", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error { + t.Fatal("UpdateUserInfo should not be called in dry-run mode") + return nil + }, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "u1@example.com", Name: "U1"}, + {ID: "user-2", Email: "u2@example.com", Name: "U2"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, true) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("skips user when IDP has empty email and name too", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "", Name: ""}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) +} + +func TestSchemaError_String(t *testing.T) { + t.Run("missing table", func(t *testing.T) { + e := SchemaError{Table: "jobs"} + assert.Equal(t, `table "jobs" is missing`, e.String()) + }) + + t.Run("missing column", func(t *testing.T) { + e := SchemaError{Table: "users", Column: "email"} + assert.Equal(t, `column "email" on table "users" is missing`, e.String()) + }) +} + +func TestRequiredSchema(t *testing.T) { + // Verify RequiredSchema covers all the tables touched by UpdateUserID and UpdateUserInfo. + expectedTables := []string{ + "users", + "personal_access_tokens", + "peers", + "accounts", + "user_invites", + "proxy_access_tokens", + "jobs", + } + + schemaTableNames := make([]string, len(RequiredSchema)) + for i, s := range RequiredSchema { + schemaTableNames[i] = s.Table + } + + for _, expected := range expectedTables { + assert.Contains(t, schemaTableNames, expected, "RequiredSchema should include table %q", expected) + } +} + +func TestCheckSchema_MockStore(t *testing.T) { + t.Run("returns nil when all schema exists", func(t *testing.T) { + ms := &testStore{ + checkSchemaFunc: func(checks []SchemaCheck) []SchemaError { + return nil + }, + } + errs := ms.CheckSchema(RequiredSchema) + assert.Empty(t, errs) + }) + + t.Run("returns errors for missing tables", func(t *testing.T) { + ms := &testStore{ + checkSchemaFunc: func(checks []SchemaCheck) []SchemaError { + return []SchemaError{ + {Table: "jobs"}, + {Table: "proxy_access_tokens"}, + } + }, + } + errs := ms.CheckSchema(RequiredSchema) + require.Len(t, errs, 2) + assert.Equal(t, "jobs", errs[0].Table) + assert.Equal(t, "", errs[0].Column) + assert.Equal(t, "proxy_access_tokens", errs[1].Table) + }) + + t.Run("returns errors for missing columns", func(t *testing.T) { + ms := &testStore{ + checkSchemaFunc: func(checks []SchemaCheck) []SchemaError { + return []SchemaError{ + {Table: "users", Column: "email"}, + {Table: "users", Column: "name"}, + } + }, + } + errs := ms.CheckSchema(RequiredSchema) + require.Len(t, errs, 2) + assert.Equal(t, "users", errs[0].Table) + assert.Equal(t, "email", errs[0].Column) + }) +} diff --git a/management/server/idp/migration/store.go b/management/server/idp/migration/store.go new file mode 100644 index 000000000..e7cc54a41 --- /dev/null +++ b/management/server/idp/migration/store.go @@ -0,0 +1,82 @@ +package migration + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/management/server/types" +) + +// SchemaCheck represents a table and the columns required on it. +type SchemaCheck struct { + Table string + Columns []string +} + +// RequiredSchema lists all tables and columns that the migration tool needs. +// If any are missing, the user must upgrade their management server first so +// that the automatic GORM migrations create them. +var RequiredSchema = []SchemaCheck{ + {Table: "users", Columns: []string{"id", "email", "name", "account_id"}}, + {Table: "personal_access_tokens", Columns: []string{"user_id", "created_by"}}, + {Table: "peers", Columns: []string{"user_id"}}, + {Table: "accounts", Columns: []string{"created_by"}}, + {Table: "user_invites", Columns: []string{"created_by"}}, + {Table: "proxy_access_tokens", Columns: []string{"created_by"}}, + {Table: "jobs", Columns: []string{"triggered_by"}}, +} + +// SchemaError describes a single missing table or column. +type SchemaError struct { + Table string + Column string // empty when the whole table is missing +} + +func (e SchemaError) String() string { + if e.Column == "" { + return fmt.Sprintf("table %q is missing", e.Table) + } + return fmt.Sprintf("column %q on table %q is missing", e.Column, e.Table) +} + +// Store defines the data store operations required for IdP user migration. +// This interface is separate from the main store.Store interface because these methods +// are only used during one-time migration and should be removed once migration tooling +// is no longer needed. +// +// The SQL store implementations (SqlStore) already have these methods on their concrete +// types, so they satisfy this interface via Go's structural typing with zero code changes. +type Store interface { + // ListUsers returns all users across all accounts. + ListUsers(ctx context.Context) ([]*types.User, error) + + // UpdateUserID atomically updates a user's ID and all foreign key references + // across the database (peers, groups, policies, PATs, etc.). + UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error + + // UpdateUserInfo updates a user's email and name in the store. + UpdateUserInfo(ctx context.Context, userID, email, name string) error + + // CheckSchema verifies that all tables and columns required by the migration + // exist in the database. Returns a list of problems; an empty slice means OK. + CheckSchema(checks []SchemaCheck) []SchemaError +} + +// RequiredEventSchema lists all tables and columns that the migration tool needs +// in the activity/event store. +var RequiredEventSchema = []SchemaCheck{ + {Table: "events", Columns: []string{"initiator_id", "target_id"}}, + {Table: "deleted_users", Columns: []string{"id"}}, +} + +// EventStore defines the activity event store operations required for migration. +// Like Store, this is a temporary interface for migration tooling only. +type EventStore interface { + // CheckSchema verifies that all tables and columns required by the migration + // exist in the event database. Returns a list of problems; an empty slice means OK. + CheckSchema(checks []SchemaCheck) []SchemaError + + // UpdateUserID updates all event references (initiator_id, target_id) and + // deleted_users records to use the new user ID format. + UpdateUserID(ctx context.Context, oldUserID, newUserID string) error +} diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 19e3abdc0..2c355bb3b 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/dexidp/dex/storage" goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" @@ -60,14 +61,31 @@ type Manager interface { // This should only be called when IsSetupRequired returns true. CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) + // RollbackSetup reverses a successful CreateOwnerUser by deleting the user + // from the embedded IDP and reloading setupRequired from persistent state, so + // /api/setup can be retried only when no accounts or local users remain. Used + // when post-user steps (account or PAT creation) fail and the caller wants a + // clean slate. + RollbackSetup(ctx context.Context, userID string) error + // GetVersionInfo returns version information for NetBird components. GetVersionInfo(ctx context.Context) (*VersionInfo, error) } +type instanceStore interface { + GetAccountsCounter(ctx context.Context) (int64, error) +} + +type embeddedIdP interface { + CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) + DeleteUser(ctx context.Context, userID string) error + GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error) +} + // DefaultManager is the default implementation of Manager. type DefaultManager struct { - store store.Store - embeddedIdpManager *idp.EmbeddedIdPManager + store instanceStore + embeddedIdpManager embeddedIdP setupRequired bool setupMu sync.RWMutex @@ -82,18 +100,18 @@ type DefaultManager struct { // NewManager creates a new instance manager. // If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults. func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) { - embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager) + embeddedIdp, ok := idpManager.(*idp.EmbeddedIdPManager) m := &DefaultManager{ - store: store, - embeddedIdpManager: embeddedIdp, - setupRequired: false, + store: store, + setupRequired: false, httpClient: &http.Client{ Timeout: httpTimeout, }, } - if embeddedIdp != nil { + if ok && embeddedIdp != nil { + m.embeddedIdpManager = embeddedIdp err := m.loadSetupRequired(ctx) if err != nil { return nil, err @@ -143,36 +161,106 @@ func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) { // CreateOwnerUser creates the initial owner user in the embedded IDP. func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { - if err := m.validateSetupInfo(email, password, name); err != nil { - return nil, err - } - if m.embeddedIdpManager == nil { return nil, errors.New("embedded IDP is not enabled") } - m.setupMu.RLock() - setupRequired := m.setupRequired - m.setupMu.RUnlock() + if err := m.validateSetupInfo(email, password, name); err != nil { + return nil, err + } - if !setupRequired { + m.setupMu.Lock() + defer m.setupMu.Unlock() + + if !m.setupRequired { return nil, status.Errorf(status.PreconditionFailed, "setup already completed") } + if err := m.checkSetupRequiredFromDB(ctx); err != nil { + var sErr *status.Error + if errors.As(err, &sErr) && sErr.Type() == status.PreconditionFailed { + m.setupRequired = false + } + return nil, err + } + userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name) if err != nil { return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err) } - m.setupMu.Lock() m.setupRequired = false - m.setupMu.Unlock() log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email) return userData, nil } +// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the +// embedded IDP and reloads setupRequired from persistent state. +func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error { + if m.embeddedIdpManager == nil { + return errors.New("embedded IDP is not enabled") + } + + var deleteErr error + if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil { + if isNotFoundError(err) { + log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID) + } else { + deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err) + } + } + + if err := m.loadSetupRequired(ctx); err != nil { + reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err) + if deleteErr != nil { + return errors.Join(deleteErr, reloadErr) + } + return reloadErr + } + + if deleteErr != nil { + return deleteErr + } + + log.WithContext(ctx).Infof("rolled back setup for user %s", userID) + return nil +} + +func isNotFoundError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, storage.ErrNotFound) { + return true + } + if s, ok := status.FromError(err); ok { + return s.Type() == status.NotFound + } + return false +} + +func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error { + numAccounts, err := m.store.GetAccountsCounter(ctx) + if err != nil { + return fmt.Errorf("failed to check accounts: %w", err) + } + if numAccounts > 0 { + return status.Errorf(status.PreconditionFailed, "setup already completed") + } + + users, err := m.embeddedIdpManager.GetAllAccounts(ctx) + if err != nil { + return fmt.Errorf("failed to check IdP users: %w", err) + } + if len(users) > 0 { + return status.Errorf(status.PreconditionFailed, "setup already completed") + } + + return nil +} + func (m *DefaultManager) validateSetupInfo(email, password, name string) error { if email == "" { return status.Errorf(status.InvalidArgument, "email is required") @@ -189,6 +277,9 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error { if len(password) < 8 { return status.Errorf(status.InvalidArgument, "password must be at least 8 characters") } + if len(password) > 72 { + return status.Errorf(status.InvalidArgument, "password must be at most 72 characters") + } return nil } diff --git a/management/server/instance/manager_test.go b/management/server/instance/manager_test.go index 35d0ff53c..5ffb016de 100644 --- a/management/server/instance/manager_test.go +++ b/management/server/instance/manager_test.go @@ -3,181 +3,309 @@ package instance import ( "context" "errors" + "fmt" + "net/http" + "sync" + "sync/atomic" "testing" + "time" + "github.com/dexidp/dex/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/shared/management/status" ) -// mockStore implements a minimal store.Store for testing +type mockIdP struct { + mu sync.Mutex + createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) + deleteUserFunc func(ctx context.Context, userID string) error + users map[string][]*idp.UserData + getAllAccountsErr error +} + +func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.createUserFunc != nil { + return m.createUserFunc(ctx, email, password, name) + } + return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil +} + +func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error { + if m.deleteUserFunc != nil { + return m.deleteUserFunc(ctx, userID) + } + return nil +} + +func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.getAllAccountsErr != nil { + return nil, m.getAllAccountsErr + } + return m.users, nil +} + type mockStore struct { accountsCount int64 err error } -func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) { +func (m *mockStore) GetAccountsCounter(_ context.Context) (int64, error) { if m.err != nil { return 0, m.err } return m.accountsCount, nil } -// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing -type mockEmbeddedIdPManager struct { - createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) -} - -func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) { - if m.createUserFunc != nil { - return m.createUserFunc(ctx, email, password, name) +func newTestManager(idpMock *mockIdP, storeMock *mockStore) *DefaultManager { + return &DefaultManager{ + store: storeMock, + embeddedIdpManager: idpMock, + setupRequired: true, + httpClient: &http.Client{Timeout: httpTimeout}, } - return &idp.UserData{ - ID: "test-user-id", - Email: email, - Name: name, - }, nil -} - -// testManager is a test implementation that accepts our mock types -type testManager struct { - store *mockStore - embeddedIdpManager *mockEmbeddedIdPManager -} - -func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) { - if m.embeddedIdpManager == nil { - return false, nil - } - - count, err := m.store.GetAccountsCounter(ctx) - if err != nil { - return false, err - } - - return count == 0, nil -} - -func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { - if m.embeddedIdpManager == nil { - return nil, errors.New("embedded IDP is not enabled") - } - - return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name) -} - -func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: nil, // No embedded IDP - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.False(t, required, "setup should not be required when embedded IDP is disabled") -} - -func TestIsSetupRequired_NoAccounts(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.True(t, required, "setup should be required when no accounts exist") -} - -func TestIsSetupRequired_AccountsExist(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 1}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.False(t, required, "setup should not be required when accounts exist") -} - -func TestIsSetupRequired_MultipleAccounts(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 5}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.False(t, required, "setup should not be required when multiple accounts exist") -} - -func TestIsSetupRequired_StoreError(t *testing.T) { - manager := &testManager{ - store: &mockStore{err: errors.New("database error")}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - _, err := manager.IsSetupRequired(context.Background()) - assert.Error(t, err, "should return error when store fails") } func TestCreateOwnerUser_Success(t *testing.T) { - expectedEmail := "admin@example.com" - expectedName := "Admin User" - expectedPassword := "securepassword123" + idpMock := &mockIdP{} + mgr := newTestManager(idpMock, &mockStore{}) - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: &mockEmbeddedIdPManager{ - createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { - assert.Equal(t, expectedEmail, email) - assert.Equal(t, expectedPassword, password) - assert.Equal(t, expectedName, name) - return &idp.UserData{ - ID: "created-user-id", - Email: email, - Name: name, - }, nil - }, - }, - } - - userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName) + userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") require.NoError(t, err) - assert.Equal(t, "created-user-id", userData.ID) - assert.Equal(t, expectedEmail, userData.Email) - assert.Equal(t, expectedName, userData.Name) + assert.Equal(t, "admin@example.com", userData.Email) + + _, err = mgr.CreateOwnerUser(context.Background(), "admin2@example.com", "password123", "Admin2") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") +} + +func TestCreateOwnerUser_SetupAlreadyCompleted(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{}) + mgr.setupRequired = false + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") } func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: nil, - } + mgr := &DefaultManager{setupRequired: true} - _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") - assert.Error(t, err, "should return error when embedded IDP is disabled") + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) assert.Contains(t, err.Error(), "embedded IDP is not enabled") } func TestCreateOwnerUser_IdPError(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: &mockEmbeddedIdPManager{ - createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { - return nil, errors.New("user already exists") - }, + idpMock := &mockIdP{ + createUserFunc: func(_ context.Context, _, _, _ string) (*idp.UserData, error) { + return nil, errors.New("provider error") + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "provider error") + + required, _ := mgr.IsSetupRequired(context.Background()) + assert.True(t, required, "setup should still be required after IdP error") +} + +func TestCreateOwnerUser_TransientDBError_DoesNotBlockSetup(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{err: errors.New("connection refused")}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "connection refused") + + required, _ := mgr.IsSetupRequired(context.Background()) + assert.True(t, required, "setup should still be required after transient DB error") + + mgr.store = &mockStore{} + userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.NoError(t, err) + assert.Equal(t, "admin@example.com", userData.Email) +} + +func TestCreateOwnerUser_TransientIdPError_DoesNotBlockSetup(t *testing.T) { + idpMock := &mockIdP{getAllAccountsErr: errors.New("connection reset")} + mgr := newTestManager(idpMock, &mockStore{}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "connection reset") + + required, _ := mgr.IsSetupRequired(context.Background()) + assert.True(t, required, "setup should still be required after transient IdP error") + + idpMock.getAllAccountsErr = nil + userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.NoError(t, err) + assert.Equal(t, "admin@example.com", userData.Email) +} + +func TestCreateOwnerUser_DBCheckBlocksConcurrent(t *testing.T) { + idpMock := &mockIdP{ + users: map[string][]*idp.UserData{ + "acc1": {{ID: "existing-user"}}, + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") +} + +func TestCreateOwnerUser_DBCheckBlocksWhenAccountsExist(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{accountsCount: 1}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") +} + +func TestCreateOwnerUser_ConcurrentRequests(t *testing.T) { + var idpCallCount atomic.Int32 + var successCount atomic.Int32 + var failCount atomic.Int32 + + idpMock := &mockIdP{ + createUserFunc: func(_ context.Context, email, _, _ string) (*idp.UserData, error) { + idpCallCount.Add(1) + time.Sleep(50 * time.Millisecond) + return &idp.UserData{ID: "user-1", Email: email, Name: "Owner"}, nil + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + + var wg sync.WaitGroup + for i := range 10 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, err := mgr.CreateOwnerUser( + context.Background(), + fmt.Sprintf("owner%d@example.com", idx), + "password1234", + fmt.Sprintf("Owner%d", idx), + ) + if err != nil { + failCount.Add(1) + } else { + successCount.Add(1) + } + }(i) + } + wg.Wait() + + assert.Equal(t, int32(1), successCount.Load(), "exactly one concurrent setup request should succeed") + assert.Equal(t, int32(9), failCount.Load(), "remaining concurrent requests should fail") + assert.Equal(t, int32(1), idpCallCount.Load(), "IdP CreateUser should be called exactly once") +} + +func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) { + mgr := &DefaultManager{} + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required) +} + +func TestIsSetupRequired_ReturnsFlag(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{}) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required) + + mgr.setupMu.Lock() + mgr.setupRequired = false + mgr.setupMu.Unlock() + + required, err = mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required) +} + +func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "management status not found", + err: status.NewUserNotFoundError("owner-id"), + }, + { + name: "dex storage not found", + err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound), }, } - _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") - assert.Error(t, err, "should return error when IDP fails") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, userID string) error { + assert.Equal(t, "owner-id", userID) + return tt.err + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + mgr.setupRequired = false + + err := mgr.RollbackSetup(context.Background(), "owner-id") + require.NoError(t, err) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required, "setup should be required when no accounts or local users remain") + }) + } +} + +func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, _ string) error { + return status.NewUserNotFoundError("owner-id") + }, + } + mgr := newTestManager(idpMock, &mockStore{accountsCount: 1}) + mgr.setupRequired = true + + err := mgr.RollbackSetup(context.Background(), "owner-id") + require.NoError(t, err) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required, "setup should not be required while an account still exists") +} + +func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, _ string) error { + return errors.New("idp unavailable") + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + mgr.setupRequired = false + + err := mgr.RollbackSetup(context.Background(), "owner-id") + require.Error(t, err) + assert.Contains(t, err.Error(), "idp unavailable") + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required, "setup state should be reloaded even when user deletion fails") } func TestDefaultManager_ValidateSetupRequest(t *testing.T) { - manager := &DefaultManager{ - setupRequired: true, - } + manager := &DefaultManager{setupRequired: true} tests := []struct { name string @@ -188,11 +316,10 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) { errorMsg string }{ { - name: "valid request", - email: "admin@example.com", - password: "password123", - userName: "Admin User", - expectError: false, + name: "valid request", + email: "admin@example.com", + password: "password123", + userName: "Admin User", }, { name: "empty email", @@ -235,11 +362,24 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) { errorMsg: "password must be at least 8 characters", }, { - name: "password exactly 8 characters", + name: "password exactly 8 characters", + email: "admin@example.com", + password: "12345678", + userName: "Admin User", + }, + { + name: "password exactly 72 characters", + email: "admin@example.com", + password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiii", + userName: "Admin User", + }, + { + name: "password too long", email: "admin@example.com", - password: "12345678", + password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiiij", userName: "Admin User", - expectError: false, + expectError: true, + errorMsg: "password must be at most 72 characters", }, } @@ -255,14 +395,3 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) { }) } } - -func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) { - manager := &DefaultManager{ - setupRequired: false, - embeddedIdpManager: &idp.EmbeddedIdPManager{}, - } - - _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") - require.Error(t, err) - assert.Contains(t, err.Error(), "setup already completed") -} diff --git a/management/server/instance/setup_service.go b/management/server/instance/setup_service.go new file mode 100644 index 000000000..92a4923be --- /dev/null +++ b/management/server/instance/setup_service.go @@ -0,0 +1,216 @@ +package instance + +import ( + "context" + "fmt" + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + setupPATTokenName = "setup-token" + + // SetupPATEnabledEnvKey enables setup-time Personal Access Token creation. + SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED" + + setupPATDefaultExpireDays = 1 +) + +// SetupOptions controls optional work performed during initial instance setup. +type SetupOptions struct { + // CreatePAT requests creation of a setup Personal Access Token. It is honored + // only when SetupPATEnabledEnvKey is set to "true". + CreatePAT bool + // PATExpireInDays defaults to 1 day when CreatePAT is requested and setup PAT + // creation is enabled. + PATExpireInDays *int +} + +// SetupResult contains resources created during initial instance setup. +type SetupResult struct { + User *idp.UserData + PATPlainToken string +} + +// SetupService orchestrates the initial setup use case across the instance and +// account bounded contexts and owns the compensation logic when a later step +// fails. +type SetupService struct { + instanceManager Manager + accountManager account.Manager + setupPATEnabled bool +} + +// NewSetupService creates a setup use-case service. +func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService { + return &SetupService{ + instanceManager: instanceManager, + accountManager: accountManager, + setupPATEnabled: os.Getenv(SetupPATEnabledEnvKey) == "true", + } +} + +func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOptions, error) { + if !opts.CreatePAT { + return opts, nil + } + + if !setupPATEnabled { + opts.CreatePAT = false + opts.PATExpireInDays = nil + return opts, nil + } + + if opts.PATExpireInDays == nil { + defaultExpireInDays := setupPATDefaultExpireDays + opts.PATExpireInDays = &defaultExpireInDays + } + + if *opts.PATExpireInDays < account.PATMinExpireDays || *opts.PATExpireInDays > account.PATMaxExpireDays { + return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays) + } + + return opts, nil +} + +// SetupOwner creates the initial owner user and, when requested and enabled by +// SetupPATEnabledEnvKey, provisions the account and a setup Personal Access +// Token. If account or PAT provisioning fails, created resources are rolled +// back so setup can be retried. If account rollback fails, user rollback is +// skipped to avoid leaving an account without its owner user. +func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) { + opts, err := normalizeSetupOptions(opts, m.setupPATEnabled) + if err != nil { + return nil, err + } + + if opts.CreatePAT && m.accountManager == nil { + return nil, fmt.Errorf("account manager is required to create setup PAT") + } + + userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name) + if err != nil { + return nil, err + } + + result := &SetupResult{User: userData} + if !opts.CreatePAT { + return result, nil + } + + userAuth := auth.UserAuth{ + UserId: userData.ID, + Email: userData.Email, + Name: userData.Name, + } + + accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth) + if err != nil { + err = fmt.Errorf("create account for setup user: %w", err) + if rollbackErr := m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, ""); rollbackErr != nil { + return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr) + } + return nil, err + } + + pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, *opts.PATExpireInDays) + if err != nil { + err = fmt.Errorf("create setup PAT: %w", err) + if rollbackErr := m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID); rollbackErr != nil { + return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr) + } + return nil, err + } + + result.PATPlainToken = pat.PlainToken + return result, nil +} + +func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) error { + if accountID == "" { + resolvedAccountID, err := m.lookupSetupAccountIDForRollback(ctx, userID) + if err != nil { + rollbackErr := fmt.Errorf("resolve setup account for rollback: %w", err) + log.WithContext(ctx).Errorf("failed to resolve setup account for user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr) + return rollbackErr + } + accountID = resolvedAccountID + } + + if accountID != "" { + if err := m.rollbackSetupAccount(ctx, accountID); err != nil { + rollbackErr := fmt.Errorf("roll back setup account %s: %w", accountID, err) + log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, rollbackErr) + return rollbackErr + } + log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr) + } + + if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil { + rollbackErr := fmt.Errorf("roll back setup user %s: %w", userID, err) + log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr) + return rollbackErr + } + log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr) + return nil +} + +func (m *SetupService) lookupSetupAccountIDForRollback(ctx context.Context, userID string) (string, error) { + if m.accountManager == nil { + return "", fmt.Errorf("account manager is required to resolve setup account") + } + + accountStore := m.accountManager.GetStore() + if accountStore == nil { + return "", fmt.Errorf("account store is unavailable") + } + + accountID, err := accountStore.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID) + if err != nil { + if isNotFoundError(err) { + return "", nil + } + return "", fmt.Errorf("get setup account ID for rollback: %w", err) + } + + return accountID, nil +} + +// rollbackSetupAccount removes only the setup-created account data from the +// store. It intentionally avoids accountManager.DeleteAccount because the normal +// account deletion path also deletes users from the IdP; embedded IdP cleanup is +// owned by instanceManager.RollbackSetup. +func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error { + if m.accountManager == nil { + return fmt.Errorf("account manager is required to roll back setup account") + } + + accountStore := m.accountManager.GetStore() + if accountStore == nil { + return fmt.Errorf("account store is unavailable") + } + + account, err := accountStore.GetAccount(ctx, accountID) + if err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("get setup account for rollback: %w", err) + } + + if err := accountStore.DeleteAccount(ctx, account); err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("delete setup account for rollback: %w", err) + } + + return nil +} diff --git a/management/server/instance/setup_service_test.go b/management/server/instance/setup_service_test.go new file mode 100644 index 000000000..12ec7d0fa --- /dev/null +++ b/management/server/instance/setup_service_test.go @@ -0,0 +1,318 @@ +package instance + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/mock_server" + nbstore "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/status" +) + +type setupInstanceManagerMock struct { + createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + rollbackSetupFn func(ctx context.Context, userID string) error +} + +func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) { + return true, nil +} + +func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.createOwnerUserFn != nil { + return m.createOwnerUserFn(ctx, email, password, name) + } + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil +} + +func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error { + if m.rollbackSetupFn != nil { + return m.rollbackSetupFn(ctx, userID) + } + return nil +} + +func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) { + return &VersionInfo{}, nil +} + +var _ Manager = (*setupInstanceManagerMock)(nil) + +func intPtr(v int) *int { + return &v +} + +func TestSetupOwner_PATFeatureDisabled_IgnoresCreatePAT(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "false") + + createCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalls++ + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + }, + &mock_server.MockAccountManager{}, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "owner-id", result.User.ID) + assert.Empty(t, result.PATPlainToken) + assert.Equal(t, 1, createCalls) +} + +func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + createCalled := false + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiatorUserID) + assert.Equal(t, "owner-id", targetUserID) + assert.Equal(t, setupPATTokenName, tokenName) + assert.Equal(t, setupPATDefaultExpireDays, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, createCalled) + assert.Equal(t, "nbp_plain", result.PATPlainToken) +} + +func TestSetupOwner_PATFeatureEnabled_MissingAccountManagerFailsBeforeCreateUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + createCalled := false + rollbackCalled := false + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, _ string) error { + rollbackCalled = true + return nil + }, + }, + nil, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "account manager is required") + assert.False(t, createCalled) + assert.False(t, rollbackCalled) +} + +func TestSetupOwner_AccountProvisioningFails_RollsBackSideEffectAccountAndUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("acc-1", nil) + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rolledBackFor := "" + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + rolledBackFor = userID + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "", errors.New("metadata update failed") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create account for setup user") + assert.Equal(t, "owner-id", rolledBackFor) + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + assert.Equal(t, "owner-id", userID) + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiatorUserID) + assert.Equal(t, "owner-id", targetUserID) + assert.Equal(t, setupPATTokenName, tokenName) + assert.Equal(t, 30, expiresIn) + return nil, status.Errorf(status.Internal, "token store unavailable") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1")) + + rolledBackFor := "" + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + rolledBackFor = userID + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, errors.New("token failure") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Equal(t, "owner-id", rolledBackFor) + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_AccountRollbackFailureStopsBeforeUserRollback(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed")) + + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, errors.New("token failure") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Contains(t, err.Error(), "failed to roll back setup resources") + assert.Equal(t, 0, rollbackCalls) +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 090c99877..1b77ea335 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -266,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) { } // expired peers come separately. - if len(networkMap.GetOfflinePeers()) != 1 { - t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer") + if len(networkMap.GetOfflinePeers()) != 2 { + t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer") } expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=" @@ -369,9 +370,15 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config requestBuffer := NewAccountRequestBuffer(ctx, store) ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", - eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { cleanup() @@ -384,7 +391,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index de02855bf..f1d49193c 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -28,6 +28,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -207,6 +208,12 @@ func startServer( jobManager := job.NewJobManager(nil, str, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("failed creating cache store: %v", err) + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) @@ -227,7 +234,8 @@ func startServer( port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, - false) + false, + cacheStore) if err != nil { t.Fatalf("failed creating an account manager: %v", err) } @@ -248,6 +256,7 @@ func startServer( server.MockIntegratedValidator{}, networkMapController, nil, + nil, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 29555ed0c..7a51cc200 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -489,6 +489,102 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri return nil } +// hasForeignKey checks whether a foreign key constraint exists on the given table and column. +func hasForeignKey(db *gorm.DB, table, column string) bool { + var count int64 + + switch db.Name() { + case "postgres": + db.Raw(` + SELECT COUNT(*) FROM information_schema.key_column_usage kcu + JOIN information_schema.table_constraints tc + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND kcu.table_name = ? + AND kcu.column_name = ? + `, table, column).Scan(&count) + case "mysql": + db.Raw(` + SELECT COUNT(*) FROM information_schema.key_column_usage + WHERE table_schema = DATABASE() + AND table_name = ? + AND column_name = ? + AND referenced_table_name IS NOT NULL + `, table, column).Scan(&count) + default: // sqlite + type fkInfo struct { + From string + } + var fks []fkInfo + db.Raw(fmt.Sprintf("PRAGMA foreign_key_list(%s)", table)).Scan(&fks) + for _, fk := range fks { + if fk.From == column { + return true + } + } + return false + } + + return count > 0 +} + +// CleanupOrphanedResources deletes rows from the table of model T where the foreign +// key column (fkColumn) references a row in the table of model R that no longer exists. +func CleanupOrphanedResources[T any, R any](ctx context.Context, db *gorm.DB, fkColumn string) error { + var model T + var refModel R + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no cleanup needed", model) + return nil + } + + if !db.Migrator().HasTable(&refModel) { + log.WithContext(ctx).Debugf("referenced table for %T does not exist, no cleanup needed", refModel) + return nil + } + + stmtT := &gorm.Statement{DB: db} + if err := stmtT.Parse(&model); err != nil { + return fmt.Errorf("parse model %T: %w", model, err) + } + childTable := stmtT.Schema.Table + + stmtR := &gorm.Statement{DB: db} + if err := stmtR.Parse(&refModel); err != nil { + return fmt.Errorf("parse reference model %T: %w", refModel, err) + } + parentTable := stmtR.Schema.Table + + if !db.Migrator().HasColumn(&model, fkColumn) { + log.WithContext(ctx).Debugf("column %s does not exist in table %s, no cleanup needed", fkColumn, childTable) + return nil + } + + // If a foreign key constraint already exists on the column, the DB itself + // enforces referential integrity and orphaned rows cannot exist. + if hasForeignKey(db, childTable, fkColumn) { + log.WithContext(ctx).Debugf("foreign key constraint for %s already exists on %s, no cleanup needed", fkColumn, childTable) + return nil + } + + result := db.Exec( + fmt.Sprintf( + "DELETE FROM %s WHERE %s NOT IN (SELECT id FROM %s)", + childTable, fkColumn, parentTable, + ), + ) + if result.Error != nil { + return fmt.Errorf("cleanup orphaned rows in %s: %w", childTable, result.Error) + } + + log.WithContext(ctx).Infof("Cleaned up %d orphaned rows from %s where %s had no matching row in %s", + result.RowsAffected, childTable, fkColumn, parentTable) + + return nil +} + func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error { if !db.Migrator().HasTable("peers") { log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup") diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index c1be8a3a3..5e00976c2 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -441,3 +441,197 @@ func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) { err := migration.RemoveDuplicatePeerKeys(context.Background(), db) require.NoError(t, err, "Should not fail when table does not exist") } + +type testParent struct { + ID string `gorm:"primaryKey"` +} + +func (testParent) TableName() string { + return "test_parents" +} + +type testChild struct { + ID string `gorm:"primaryKey"` + ParentID string +} + +func (testChild) TableName() string { + return "test_children" +} + +type testChildWithFK struct { + ID string `gorm:"primaryKey"` + ParentID string `gorm:"index"` + Parent *testParent `gorm:"foreignKey:ParentID"` +} + +func (testChildWithFK) TableName() string { + return "test_children" +} + +func setupOrphanTestDB(t *testing.T, models ...any) *gorm.DB { + t.Helper() + db := setupDatabase(t) + for _, m := range models { + _ = db.Migrator().DropTable(m) + } + err := db.AutoMigrate(models...) + require.NoError(t, err, "Failed to auto-migrate tables") + return db +} + +func TestCleanupOrphanedResources_NoChildTable(t *testing.T) { + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testChild{}) + _ = db.Migrator().DropTable(&testParent{}) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err, "Should not fail when child table does not exist") +} + +func TestCleanupOrphanedResources_NoParentTable(t *testing.T) { + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testParent{}) + _ = db.Migrator().DropTable(&testChild{}) + + err := db.AutoMigrate(&testChild{}) + require.NoError(t, err) + + err = migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err, "Should not fail when parent table does not exist") +} + +func TestCleanupOrphanedResources_EmptyTables(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err, "Should not fail on empty tables") + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(0), count) +} + +func TestCleanupOrphanedResources_NoOrphans(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testParent{ID: "p2"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(2), count, "All children should remain when no orphans") +} + +func TestCleanupOrphanedResources_AllOrphans(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c1", "gone1").Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone2").Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c3", "gone3").Error) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(0), count, "All orphaned children should be deleted") +} + +func TestCleanupOrphanedResources_MixedValidAndOrphaned(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testParent{ID: "p2"}).Error) + + require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c3", ParentID: "p1"}).Error) + + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c4", "gone1").Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c5", "gone2").Error) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var remaining []testChild + require.NoError(t, db.Order("id").Find(&remaining).Error) + + assert.Len(t, remaining, 3, "Only valid children should remain") + assert.Equal(t, "c1", remaining[0].ID) + assert.Equal(t, "c2", remaining[1].ID) + assert.Equal(t, "c3", remaining[2].ID) +} + +func TestCleanupOrphanedResources_Idempotent(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone").Error) + + ctx := context.Background() + + err := migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(1), count) + + err = migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id") + require.NoError(t, err) + + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(1), count, "Count should remain the same after second run") +} + +func TestCleanupOrphanedResources_SkipsWhenForeignKeyExists(t *testing.T) { + engine := os.Getenv("NETBIRD_STORE_ENGINE") + if engine != "postgres" && engine != "mysql" { + t.Skip("FK constraint early-exit test requires postgres or mysql") + } + + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testChildWithFK{}) + _ = db.Migrator().DropTable(&testParent{}) + + err := db.AutoMigrate(&testParent{}, &testChildWithFK{}) + require.NoError(t, err) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testParent{ID: "p2"}).Error) + require.NoError(t, db.Create(&testChildWithFK{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Create(&testChildWithFK{ID: "c2", ParentID: "p2"}).Error) + + switch engine { + case "postgres": + require.NoError(t, db.Exec("ALTER TABLE test_children DROP CONSTRAINT fk_test_children_parent").Error) + require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error) + require.NoError(t, db.Exec( + "ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+ + "FOREIGN KEY (parent_id) REFERENCES test_parents(id) NOT VALID", + ).Error) + case "mysql": + require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error) + require.NoError(t, db.Exec("ALTER TABLE test_children DROP FOREIGN KEY fk_test_children_parent").Error) + require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error) + require.NoError(t, db.Exec( + "ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+ + "FOREIGN KEY (parent_id) REFERENCES test_parents(id)", + ).Error) + require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error) + } + + err = migration.CleanupOrphanedResources[testChildWithFK, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChildWithFK{}).Count(&count) + assert.Equal(t, int64(2), count, "Both rows should survive — migration must skip when FK constraint exists") +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index afd2021ac..ac4d0c6d6 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -46,7 +46,7 @@ type MockAccountManager struct { AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error @@ -128,8 +128,8 @@ type MockAccountManager struct { GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) AllowSyncFunc func(string, uint64) bool - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason) RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) @@ -200,15 +200,15 @@ func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userI return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented") } -func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { +func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { if am.UpdateAccountPeersFunc != nil { - am.UpdateAccountPeersFunc(ctx, accountID) + am.UpdateAccountPeersFunc(ctx, accountID, reason) } } -func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { if am.BufferUpdateAccountPeersFunc != nil { - am.BufferUpdateAccountPeersFunc(ctx, accountID) + am.BufferUpdateAccountPeersFunc(ctx, accountID, reason) } } @@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { if am.GetGroupByNameFunc != nil { - return am.GetGroupByNameFunc(ctx, accountID, groupName) + return am.GetGroupByNameFunc(ctx, groupName, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 3d8c78912..5859bfb0d 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -82,7 +82,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate}) } return newNSGroup.Copy(), nil @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -176,7 +176,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete}) } return nil diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 90b4b9687..b2c8300d6 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -17,6 +17,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -794,11 +795,17 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createNSStore(t *testing.T) (store.Store, error) { @@ -1080,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1098,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..c96b60bb2 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + serverTypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -177,7 +178,7 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete}) return nil } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 86f9b6579..5a0e26533 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -162,7 +162,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } - go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate}) return resource, nil } @@ -270,7 +270,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc } }() - go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate}) return resource, nil } @@ -352,7 +352,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete}) return nil } diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..c7c3f2ff4 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + serverTypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -119,7 +120,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) - go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate}) return router, nil } @@ -183,7 +184,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) - go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate}) return router, nil } @@ -217,7 +218,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete}) return nil } diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index e90c61a97..1293a9934 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -21,11 +21,7 @@ type NetworkRouter struct { } func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) { - if peer != "" && len(peerGroups) > 0 { - return nil, errors.New("peer and peerGroups cannot be set at the same time") - } - - return &NetworkRouter{ + r := &NetworkRouter{ ID: xid.New().String(), AccountID: accountID, NetworkID: networkID, @@ -34,7 +30,25 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup Masquerade: masquerade, Metric: metric, Enabled: enabled, - }, nil + } + + if err := r.Validate(); err != nil { + return nil, err + } + + return r, nil +} + +func (n *NetworkRouter) Validate() error { + if n.Peer != "" && len(n.PeerGroups) > 0 { + return errors.New("peer and peer_groups cannot be set at the same time") + } + + if n.Peer == "" && len(n.PeerGroups) == 0 { + return errors.New("either peer or peer_groups must be provided") + } + + return nil } func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go index 5801e3bfa..a2f2fe6e3 100644 --- a/management/server/networks/routers/types/router_test.go +++ b/management/server/networks/routers/types/router_test.go @@ -38,7 +38,7 @@ func TestNewNetworkRouter(t *testing.T) { expectedError: false, }, { - name: "Valid with no peer or peerGroups", + name: "Invalid with no peer or peerGroups", networkID: "network-3", accountID: "account-3", peer: "", @@ -46,7 +46,18 @@ func TestNewNetworkRouter(t *testing.T) { masquerade: true, metric: 300, enabled: true, - expectedError: false, + expectedError: true, + }, + { + name: "Invalid with empty peerGroups slice", + networkID: "network-5", + accountID: "account-5", + peer: "", + peerGroups: []string{}, + masquerade: true, + metric: 500, + enabled: true, + expectedError: true, }, // Invalid cases diff --git a/management/server/peer.go b/management/server/peer.go index f7cb6a0f1..25c6ecd8c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -33,8 +33,8 @@ import ( const remoteJobsMinVer = "0.64.0" -// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if -// the current user is not an admin. +// GetPeers returns peers visible to the user within an account. +// Users with "peers:read" see all peers. Otherwise, users see only their own peers, or none if restricted by account settings. func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -46,14 +46,8 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, status.NewPermissionValidationError(err) } - accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter) - if err != nil { - return nil, err - } - - // @note if the user has permission to read peers it shows all account peers if allowed { - return accountPeers, nil + return am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter) } settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) @@ -65,41 +59,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return []*nbpeer.Peer{}, nil } - // @note if it does not have permission read peers then only display it's own peers - peers := make([]*nbpeer.Peer, 0) - peersMap := make(map[string]*nbpeer.Peer) - - for _, peer := range accountPeers { - if user.Id != peer.UserID { - continue - } - peers = append(peers, peer) - peersMap[peer.ID] = peer - } - - return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers) -} - -func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, accountID string, peersMap map[string]*nbpeer.Peer, peers []*nbpeer.Peer) ([]*nbpeer.Peer, error) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, err - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - - // fetch all the peers that have access to the user's peers - for _, peer := range peers { - aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers()) - for _, p := range aclPeers { - peersMap[p.ID] = p - } - } - - return maps.Values(peersMap), nil + return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) } // MarkPeerConnected marks peer as connected (true) or disconnected (false) @@ -858,8 +818,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe if !addedByUser { opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName } + if newPeer.Status != nil && newPeer.Status.RequiresApproval { + opEvent.Meta["pending_approval"] = true + } - am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + if !temporary { + am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + } if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) @@ -1228,7 +1193,8 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se return false } -// GetPeer for a given accountID, peerID and userID error if not found. +// GetPeer returns a peer visible to the user within an account. +// Users with "peers:read" permission can access any peer. Otherwise, users can access only their own peer. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { @@ -1253,47 +1219,17 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return peer, nil } - return am.checkIfUserOwnsPeer(ctx, accountID, userID, peer) -} - -func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, err - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - - // it is also possible that user doesn't own the peer but some of his peers have access to it, - // this is a valid case, show the peer as well. - userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) - if err != nil { - return nil, err - } - - for _, p := range userPeers { - aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers()) - for _, aclPeer := range aclPeers { - if aclPeer.ID == peer.ID { - return peer, nil - } - } - } - return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peer.ID, accountID) } // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) +func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { + _ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason) } -func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) +func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { + _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason) } // UpdateAccountPeer updates a single peer that belongs to an account. @@ -1403,6 +1339,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID var peers []*nbpeer.Peer for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired { + continue + } + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) if expired { peers = append(peers, peer) @@ -1480,9 +1420,11 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } - peerDeletedEvents = append(peerDeletedEvents, func() { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) - }) + if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") { + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + } } return peerDeletedEvents, nil diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 51c16d730..36809d354 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -32,6 +32,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -178,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { testGetNetworkMapGeneral(t) } -func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testGetNetworkMapGeneral(t) -} - func testGetNetworkMapGeneral(t *testing.T) { manager, _, err := createManager(t) if err != nil { @@ -563,25 +559,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } assert.NotNil(t, peer) - // the user can see peer2 because peer1 of the user has access to peer2 due to the All group and the default rule 0 all-to-all access - peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) - if err != nil { - t.Fatal(err) - return - } - assert.NotNil(t, peer) - - // delete the all-to-all policy so that user's peer1 has no access to peer2 - for _, policy := range account.Policies { - err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) - if err != nil { - t.Fatal(err) - return - } - } - - // at this point the user can't see the details of peer2 - peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) //nolint + // the user can NOT see peer2 because it is not owned by them. + // Regular users only see peers they directly own. + _, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) assert.Error(t, err) // admin users can always access all the peers @@ -995,7 +975,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.UpdateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{}) } duration := time.Since(start) @@ -1015,11 +995,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } -func TestUpdateAccountPeers_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testUpdateAccountPeers(t) -} - func TestUpdateAccountPeers(t *testing.T) { testUpdateAccountPeers(t) } @@ -1058,7 +1033,7 @@ func testUpdateAccountPeers(t *testing.T) { peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID) } - manager.UpdateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{}) for _, channel := range peerChannels { update := <-channel @@ -1294,11 +1269,15 @@ func Test_RegisterPeerByUser(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1380,11 +1359,15 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1534,11 +1517,15 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1587,7 +1574,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ -1615,11 +1601,15 @@ func Test_LoginPeer(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1890,7 +1880,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1912,7 +1902,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1977,7 +1967,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1995,7 +1985,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2041,7 +2031,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2059,7 +2049,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2096,7 +2086,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2114,7 +2104,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/policy.go b/management/server/policy.go index 3e84c3d10..40f3908e3 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,6 +5,7 @@ import ( _ "embed" "github.com/rs/xid" + "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var isUpdate = policy.ID != "" var updateAccountPeers bool var action = activity.PolicyAdded + var unchanged bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { - return err - } - - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy) if err != nil { return err } - saveFunc := transaction.CreatePolicy if isUpdate { - action = activity.PolicyUpdated - saveFunc = transaction.SavePolicy - } + if policy.Equal(existingPolicy) { + logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID) + unchanged = true + return nil + } - if err = saveFunc(ctx, policy); err != nil { - return err + action = activity.PolicyUpdated + + updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy) + if err != nil { + return err + } + + if err = transaction.SavePolicy(ctx, policy); err != nil { + return err + } + } else { + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) + if err != nil { + return err + } + + if err = transaction.CreatePolicy(ctx, policy); err != nil { + return err + } } return transaction.IncrementNetworkSerial(ctx, accountID) @@ -73,10 +89,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return nil, err } + if unchanged { + return policy, nil + } + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + policyOp := types.UpdateOperationCreate + if isUpdate { + policyOp = types.UpdateOperationUpdate + } + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp}) } return policy, nil @@ -101,7 +125,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) if err != nil { return err } @@ -119,7 +143,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete}) } return nil @@ -138,34 +162,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } -// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { - if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) - if err != nil { - return false, err - } - - if !policy.Enabled && !existingPolicy.Enabled { - return false, nil - } - - for _, rule := range existingPolicy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil - } - } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) - if err != nil { - return false, err - } - - if hasPeers { +// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil } } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) +} + +func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) { + if !policy.Enabled && !existingPolicy.Enabled { + return false, nil + } + + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + for _, rule := range policy.Rules { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil @@ -175,12 +202,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } -// validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { +// validatePolicy validates the policy and its rules. For updates it returns +// the existing policy loaded from the store so callers can avoid a second read. +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) { + var existingPolicy *types.Policy if policy.ID != "" { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + var err error + existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { - return err + return nil, err } // TODO: Refactor to support multiple rules per policy @@ -191,7 +221,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri for _, rule := range policy.Rules { if rule.ID != "" && !existingRuleIDs[rule.ID] { - return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) } } } else { @@ -201,12 +231,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) if err != nil { - return err + return nil, err } postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) if err != nil { - return err + return nil, err } for i, rule := range policy.Rules { @@ -225,7 +255,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - return nil + return existingPolicy, nil } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. diff --git a/management/server/policy_test.go b/management/server/policy_test.go index a3f987732..a553b7d05 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } @@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/posture/network.go b/management/server/posture/network.go index f78744143..4b4b3ccaa 100644 --- a/management/server/posture/network.go +++ b/management/server/posture/network.go @@ -17,19 +17,48 @@ type PeerNetworkRangeCheck struct { var _ Check = (*PeerNetworkRangeCheck)(nil) +// prefixContains reports whether outer fully contains inner (equal counts as contained). +// Requires the same address family, that outer is no more specific than inner (its +// netmask is shorter or equal), and that inner's network address falls inside outer. +// This is stricter than netip.Prefix.Contains(Addr) — a peer's /24 NIC will not match a +// configured /32 rule, since the rule covers a single host but the NIC describes a whole +// subnet whose host bits are unknown. +func prefixContains(outer, inner netip.Prefix) bool { + outer = outer.Masked() + inner = inner.Masked() + return outer.Bits() <= inner.Bits() && + outer.Addr().BitLen() == inner.Addr().BitLen() && // same family + outer.Contains(inner.Addr()) +} + +// Check evaluates configured ranges against the peer's local network interface prefixes +// and its public connection IP (as a /32 or /128). A configured range matches when it +// fully contains one of those prefixes, so operators can target both private subnets +// and public CIDRs (e.g. 1.0.0.0/24, 2.2.2.2/32). Including the connection IP is what +// lets a public-range posture check work — peer.Meta.NetworkAddresses only carries +// local NIC addresses. func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { - if len(peer.Meta.NetworkAddresses) == 0 { + peerPrefixes := make([]netip.Prefix, 0, len(peer.Meta.NetworkAddresses)+1) + for _, peerNetAddr := range peer.Meta.NetworkAddresses { + peerPrefixes = append(peerPrefixes, peerNetAddr.NetIP) + } + // Unmap collapses 4-in-6 forms (::ffff:a.b.c.d) so an IPv4 range matches. + if connIP := peer.Location.ConnectionIP; len(connIP) > 0 { + if addr, ok := netip.AddrFromSlice(connIP); ok { + addr = addr.Unmap() + peerPrefixes = append(peerPrefixes, netip.PrefixFrom(addr, addr.BitLen())) + } + } + + if len(peerPrefixes) == 0 { return false, fmt.Errorf("peer's does not contain peer network range addresses") } - maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges)) - for _, prefix := range p.Ranges { - maskedPrefixes = append(maskedPrefixes, prefix.Masked()) - } - - for _, peerNetAddr := range peer.Meta.NetworkAddresses { - peerMaskedPrefix := peerNetAddr.NetIP.Masked() - if slices.Contains(maskedPrefixes, peerMaskedPrefix) { + for _, peerPrefix := range peerPrefixes { + for _, rangePrefix := range p.Ranges { + if !prefixContains(rangePrefix, peerPrefix) { + continue + } switch p.Action { case CheckActionDeny: return false, nil diff --git a/management/server/posture/network_test.go b/management/server/posture/network_test.go index a841bbe08..4af394c62 100644 --- a/management/server/posture/network_test.go +++ b/management/server/posture/network_test.go @@ -2,6 +2,7 @@ package posture import ( "context" + "net" "net/netip" "testing" @@ -134,6 +135,205 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) { wantErr: true, isValid: false, }, + { + name: "Peer connection IP matches the denied /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.0.123/24")}, + }, + }, + Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Peer connection IP does not match the denied /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.0.123/24")}, + }, + }, + Location: nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Peer connection IP matches the allowed /32 with no NetworkAddresses", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "IPv6 connection IP matches the denied /128", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::1")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "IPv6 connection IP does not match the denied /128", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::2")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "IPv4-mapped IPv6 connection IP matches IPv4 /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("::ffff:109.41.115.194")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Connection IP falls inside an allowed /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + netip.MustParsePrefix("2.2.2.2/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.55")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Connection IP falls inside an allowed /23 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("3.0.0.0/23"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("3.0.1.200")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Connection IP outside the allowed /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.1.5")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Connection IP inside a denied /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.7")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Local NIC /24 does not match a /32 rule even if host bit lines up", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.5/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.0.5/24")}, + }, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "Local NIC address inside an allowed /16 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.5.7/24")}, + }, + }, + }, + wantErr: false, + isValid: true, + }, + { + name: "Empty NetworkAddresses and empty ConnectionIP still errors", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{}, + wantErr: true, + isValid: false, + }, } for _, tt := range tests { diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index ba901c771..1e3ce4b8a 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -76,7 +77,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + postureOp := types.UpdateOperationCreate + if isUpdate { + postureOp = types.UpdateOperationUpdate + } + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp}) } return postureChecks, nil @@ -84,7 +89,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7f0a48dc7..394f0d896 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/route.go b/management/server/route.go index 2b4f11d05..a9561faf0 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -191,7 +191,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate}) } return newRoute, nil @@ -245,7 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate}) } return nil @@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete}) } return nil diff --git a/management/server/route_test.go b/management/server/route_test.go index d4882eff8..d0caf4b9b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,10 +2,8 @@ package server import ( "context" - "fmt" "net" "net/netip" - "sort" "testing" "time" @@ -20,6 +18,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -1293,11 +1292,17 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } @@ -1833,11 +1838,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, } - validatedPeers := make(map[string]struct{}) - for p := range account.Peers { - validatedPeers[p] = struct{}{} - } - t.Run("check applied policies for the route", func(t *testing.T) { route1 := account.Routes["route1"] policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) @@ -1851,116 +1851,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) assert.Len(t, policies, 0) }) - - t.Run("check peer routes firewall rules", func(t *testing.T) { - routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 4) - - expectedRoutesFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 80, - RouteID: "route1:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 320, - RouteID: "route1:peerA", - }, - } - additionalFirewallRule := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerJIp), - }, - Action: "accept", - Destination: "192.168.10.0/16", - Protocol: "tcp", - Port: 80, - RouteID: "route4:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerKIp), - }, - Action: "accept", - Destination: "192.168.10.0/16", - Protocol: "all", - RouteID: "route4:peerA", - }, - } - - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) - - // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) - assert.Len(t, routesFirewallRules, 2) - for _, rule := range expectedRoutesFirewallRules { - rule.RouteID = "route1:peerD" - } - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) - - // peerE is a single routing peer for route 2 and route 3 - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) - assert.Len(t, routesFirewallRules, 3) - - expectedRoutesFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, - Action: "accept", - Destination: existingNetwork.String(), - Protocol: "tcp", - PortRange: types.RulePortRange{Start: 80, End: 350}, - RouteID: "route2", - }, - { - SourceRanges: []string{"0.0.0.0/0"}, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "route3", - }, - { - SourceRanges: []string{"::/0"}, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "route3", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) - - // peerC is part of route1 distribution groups but should not receive the routes firewall rules - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) - assert.Len(t, routesFirewallRules, 0) - }) - -} - -// orderList is a helper function to sort a list of strings -func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule { - for _, rule := range ruleList { - sort.Strings(rule.SourceRanges) - } - return ruleList } func TestRouteAccountPeersUpdate(t *testing.T) { @@ -2063,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } @@ -2100,7 +1990,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2120,7 +2010,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2138,7 +2028,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2178,7 +2068,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2218,7 +2108,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2658,11 +2548,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { }, } - validatedPeers := make(map[string]struct{}) - for p := range account.Peers { - validatedPeers[p] = struct{}{} - } - t.Run("validate applied policies for different network resources", func(t *testing.T) { // Test case: Resource1 is directly applied to the policy (policyResource1) policies := account.GetPoliciesForNetworkResource("resource1") @@ -2686,127 +2571,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { policies = account.GetPoliciesForNetworkResource("resource6") assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") }) - - t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { - resourcePoliciesMap := account.GetResourcePoliciesMap() - resourceRoutersMap := account.GetResourceRoutersMap() - _, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap) - firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 4) - assert.Len(t, sourcePeers, 5) - - expectedFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 80, - RouteID: "resource2:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 320, - RouteID: "resource2:peerA", - }, - } - - additionalFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerJIp), - }, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "tcp", - Port: 80, - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "resource4:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerKIp), - }, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "resource4:peerA", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) - - // peerD is also the routing peer for resource2 - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 2) - for _, rule := range expectedFirewallRules { - rule.RouteID = "resource2:peerD" - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - assert.Len(t, sourcePeers, 3) - - // peerE is a single routing peer for resource1 and resource3 - // PeerE should only receive rules for resource1 since resource3 has no applied policy - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 1) - assert.Len(t, sourcePeers, 2) - - expectedFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, - Action: "accept", - Destination: "10.10.10.0/24", - Protocol: "tcp", - PortRange: types.RulePortRange{Start: 80, End: 350}, - RouteID: "resource1:peerE", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - - // peerC is part of distribution groups for resource2 but should not receive the firewall rules - firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) - assert.Len(t, firewallRules, 0) - - // peerL is the single routing peer for resource5 - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 1) - assert.Len(t, sourcePeers, 1) - - expectedFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.29.67/32"}, - Action: "accept", - Destination: "10.12.12.1/32", - Protocol: "tcp", - Port: 8080, - RouteID: "resource5:peerL", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - assert.Len(t, sourcePeers, 0) - - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - assert.Len(t, sourcePeers, 2) - }) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index cf030f51e..0a716d08d 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -396,6 +396,11 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er return result.Error } + result = tx.Select(clause.Associations).Delete(account.Services, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + result = tx.Select(clause.Associations).Delete(account) if result.Error != nil { return result.Error @@ -1012,10 +1017,10 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) { // GetCustomDomainsCounts returns the total and validated custom domain counts. func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) { var total, validated int64 - if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil { + if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil { return 0, 0, err } - if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil { + if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil { return 0, 0, err } return total, validated, nil @@ -1191,7 +1196,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types account.NameServerGroups[ns.ID] = &ns } account.NameServerGroupsG = nil - account.InitOnce() return &account, nil } @@ -1630,7 +1634,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc if sExtraIntegratedValidatorGroups.Valid { _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) } - account.InitOnce() return &account, nil } @@ -2080,7 +2083,8 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) { const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth, meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster, - pass_host_header, rewrite_redirects, session_private_key, session_public_key + pass_host_header, rewrite_redirects, session_private_key, session_public_key, + mode, listen_port, port_auto_assigned, source, source_peer, terminated FROM services WHERE account_id = $1` const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol, @@ -2097,6 +2101,9 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv var auth []byte var createdAt, certIssuedAt sql.NullTime var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString + var mode, source, sourcePeer sql.NullString + var terminated, portAutoAssigned sql.NullBool + var listenPort sql.NullInt64 err := row.Scan( &s.ID, &s.AccountID, @@ -2112,6 +2119,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv &s.RewriteRedirects, &sessionPrivateKey, &sessionPublicKey, + &mode, + &listenPort, + &portAutoAssigned, + &source, + &sourcePeer, + &terminated, ) if err != nil { return nil, err @@ -2143,7 +2156,24 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv if sessionPublicKey.Valid { s.SessionPublicKey = sessionPublicKey.String } - + if mode.Valid { + s.Mode = mode.String + } + if source.Valid { + s.Source = source.String + } + if sourcePeer.Valid { + s.SourcePeer = sourcePeer.String + } + if terminated.Valid { + s.Terminated = terminated.Bool + } + if portAutoAssigned.Valid { + s.PortAutoAssigned = portAutoAssigned.Bool + } + if listenPort.Valid { + s.ListenPort = uint16(listenPort.Int64) + } s.Targets = []*rpservice.Target{} return &s, nil }) @@ -3278,7 +3308,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng var peers []*nbpeer.Peer result := tx. - Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true). Find(&peers, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) @@ -4410,7 +4440,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error { // GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value. func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -4429,7 +4459,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr // GetAllProxyAccessTokens retrieves all proxy access tokens. func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -4445,7 +4475,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc // SaveProxyAccessToken saves a proxy access token to the database. func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error { - if result := s.db.WithContext(ctx).Create(token); result.Error != nil { + if result := s.db.Create(token); result.Error != nil { return status.Errorf(status.Internal, "save proxy access token: %v", result.Error) } return nil @@ -4453,7 +4483,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA // RevokeProxyAccessToken revokes a proxy access token by its ID. func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error { - result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true) + result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true) if result.Error != nil { return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error) } @@ -4467,7 +4497,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { - result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}). + result := s.db.Model(&types.ProxyAccessToken{}). Where(idQueryCondition, tokenID). Update("last_used", time.Now().UTC()) if result.Error != nil { @@ -5136,7 +5166,7 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock // GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port. func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -5152,7 +5182,7 @@ func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength // GetServicesByCluster returns all services for the given proxy cluster. func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -5262,7 +5292,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin var logs []*accesslogs.AccessLogEntry var totalCount int64 - baseQuery := s.db.WithContext(ctx). + baseQuery := s.db. Model(&accesslogs.AccessLogEntry{}). Where(accountIDCondition, accountID) @@ -5273,7 +5303,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin return nil, 0, status.Errorf(status.Internal, "failed to count access logs") } - query := s.db.WithContext(ctx). + query := s.db. Where(accountIDCondition, accountID) query = s.applyAccessLogFilters(query, filter) @@ -5310,7 +5340,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin // DeleteOldAccessLogs deletes all access logs older than the specified time func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) { - result := s.db.WithContext(ctx). + result := s.db. Where("timestamp < ?", olderThan). Delete(&accesslogs.AccessLogEntry{}) @@ -5399,7 +5429,7 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength // SaveProxy saves or updates a proxy in the database func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { - result := s.db.WithContext(ctx).Save(p) + result := s.db.Save(p) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error) return status.Errorf(status.Internal, "failed to save proxy") @@ -5411,7 +5441,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { now := time.Now() - result := s.db.WithContext(ctx). + result := s.db. Model(&proxy.Proxy{}). Where("id = ? AND status = ?", proxyID, "connected"). Update("last_seen", now) @@ -5430,7 +5460,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd ConnectedAt: &now, Status: "connected", } - if err := s.db.WithContext(ctx).Save(p).Error; err != nil { + if err := s.db.Save(p).Error; err != nil { log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err) return status.Errorf(status.Internal, "failed to create proxy on heartbeat") } @@ -5443,9 +5473,9 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { var addresses []string - result := s.db.WithContext(ctx). + result := s.db. Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5463,7 +5493,7 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, result := s.db.Model(&proxy.Proxy{}). Select("cluster_address as address, COUNT(*) as connected_proxies"). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Group("cluster_address"). Scan(&clusters) @@ -5475,11 +5505,122 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, return clusters, nil } +// proxyActiveThreshold is the maximum age of a heartbeat for a proxy to be +// considered active. Must be at least 2x the heartbeat interval (1 min). +const proxyActiveThreshold = 2 * time.Minute + +var validCapabilityColumns = map[string]struct{}{ + "supports_custom_ports": {}, + "require_subdomain": {}, + "supports_crowdsec": {}, +} + +// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "supports_custom_ports") +} + +// GetClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "require_subdomain") +} + +// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster +// have CrowdSec configured. Returns nil when no proxy reported the capability. +// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec +// requires unanimous support: a single unconfigured proxy would let requests +// bypass reputation checks. +func (s *SqlStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + return s.getClusterUnanimousCapability(ctx, clusterAddr, "supports_crowdsec") +} + +// getClusterUnanimousCapability returns an aggregated boolean capability +// requiring all active proxies in the cluster to report true. +func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAddr, column string) *bool { + if _, ok := validCapabilityColumns[column]; !ok { + log.WithContext(ctx).Errorf("invalid capability column: %s", column) + return nil + } + + var result struct { + Total int64 + Reported int64 + AllTrue bool + } + + // All active proxies must have reported the capability (no NULLs) and all + // must report true. A single unreported or false proxy means the cluster + // does not unanimously support the capability. + err := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Select("COUNT(*) AS total, "+ + "COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) AS reported, "+ + "COUNT(*) > 0 AND COUNT(*) = COUNT(CASE WHEN "+column+" = true THEN 1 END) AS all_true"). + Where("cluster_address = ? AND status = ? AND last_seen > ?", + clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). + Scan(&result).Error + + if err != nil { + log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) + return nil + } + + if result.Total == 0 || result.Reported == 0 { + return nil + } + + // If any proxy has not reported (NULL), we can't confirm unanimous support. + if result.Reported < result.Total { + v := false + return &v + } + + return &result.AllTrue +} + +// getClusterCapability returns an aggregated boolean capability for the given +// cluster. It checks active (connected, recently seen) proxies and returns: +// - *true if any proxy in the cluster has the capability set to true, +// - *false if at least one proxy reported but none set it to true, +// - nil if no proxy reported the capability at all. +func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column string) *bool { + if _, ok := validCapabilityColumns[column]; !ok { + log.WithContext(ctx).Errorf("invalid capability column: %s", column) + return nil + } + + var result struct { + HasCapability bool + AnyTrue bool + } + + err := s.db. + Model(&proxy.Proxy{}). + Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+ + "COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true"). + Where("cluster_address = ? AND status = ? AND last_seen > ?", + clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). + Scan(&result).Error + + if err != nil { + log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) + return nil + } + + if !result.HasCapability { + return nil + } + + return &result.AnyTrue +} + // CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { cutoffTime := time.Now().Add(-inactivityDuration) - result := s.db.WithContext(ctx). + result := s.db. Where("last_seen < ?", cutoffTime). Delete(&proxy.Proxy{}) diff --git a/management/server/store/sql_store_idp_migration.go b/management/server/store/sql_store_idp_migration.go new file mode 100644 index 000000000..64962845b --- /dev/null +++ b/management/server/store/sql_store_idp_migration.go @@ -0,0 +1,177 @@ +package store + +// This file contains migration-only methods on SqlStore. +// They satisfy the migration.Store interface via duck typing. +// Delete this file when migration tooling is no longer needed. + +import ( + "context" + "fmt" + + log "github.com/sirupsen/logrus" + "gorm.io/gorm" + + "github.com/netbirdio/netbird/management/server/idp/migration" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError { + migrator := s.db.Migrator() + var errs []migration.SchemaError + + for _, check := range checks { + if !migrator.HasTable(check.Table) { + errs = append(errs, migration.SchemaError{Table: check.Table}) + continue + } + for _, col := range check.Columns { + if !migrator.HasColumn(check.Table, col) { + errs = append(errs, migration.SchemaError{Table: check.Table, Column: col}) + } + } + } + + return errs +} + +func (s *SqlStore) ListUsers(ctx context.Context) ([]*types.User, error) { + tx := s.db + var users []*types.User + result := tx.Find(&users) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when listing users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue listing users from store") + } + + for _, user := range users { + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } + } + + return users, nil +} + +// txDeferFKConstraints defers foreign key constraint checks for the duration of the transaction. +// MySQL is already handled by s.transaction (SET FOREIGN_KEY_CHECKS = 0). +func (s *SqlStore) txDeferFKConstraints(tx *gorm.DB) error { + if s.storeEngine == types.SqliteStoreEngine { + return tx.Exec("PRAGMA defer_foreign_keys = ON").Error + } + + if s.storeEngine != types.PostgresStoreEngine { + return nil + } + + // GORM creates FK constraints as NOT DEFERRABLE by default, so + // SET CONSTRAINTS ALL DEFERRED is a no-op unless we ALTER them first. + err := tx.Exec(` + DO $$ DECLARE r RECORD; + BEGIN + FOR r IN SELECT conname, conrelid::regclass AS tbl + FROM pg_constraint WHERE contype = 'f' AND NOT condeferrable + LOOP + EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I DEFERRABLE INITIALLY IMMEDIATE', r.tbl, r.conname); + END LOOP; + END $$ + `).Error + if err != nil { + return fmt.Errorf("make FK constraints deferrable: %w", err) + } + return tx.Exec("SET CONSTRAINTS ALL DEFERRED").Error +} + +// txRestoreFKConstraints reverts FK constraints back to NOT DEFERRABLE after the +// deferred updates are done but before the transaction commits. +func (s *SqlStore) txRestoreFKConstraints(tx *gorm.DB) error { + if s.storeEngine != types.PostgresStoreEngine { + return nil + } + + return tx.Exec(` + DO $$ DECLARE r RECORD; + BEGIN + FOR r IN SELECT conname, conrelid::regclass AS tbl + FROM pg_constraint WHERE contype = 'f' AND condeferrable + LOOP + EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I NOT DEFERRABLE', r.tbl, r.conname); + END LOOP; + END $$ + `).Error +} + +func (s *SqlStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error { + user := &types.User{Email: email, Name: name} + if err := user.EncryptSensitiveData(s.fieldEncrypt); err != nil { + return fmt.Errorf("encrypt user info: %w", err) + } + + result := s.db.Model(&types.User{}).Where("id = ?", userID).Updates(map[string]any{ + "email": user.Email, + "name": user.Name, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("error updating user info for %s: %s", userID, result.Error) + return status.Errorf(status.Internal, "failed to update user info") + } + + return nil +} + +func (s *SqlStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error { + type fkUpdate struct { + model any + column string + where string + } + + updates := []fkUpdate{ + {&types.PersonalAccessToken{}, "user_id", "user_id = ?"}, + {&types.PersonalAccessToken{}, "created_by", "created_by = ?"}, + {&nbpeer.Peer{}, "user_id", "user_id = ?"}, + {&types.UserInviteRecord{}, "created_by", "created_by = ?"}, + {&types.Account{}, "created_by", "created_by = ?"}, + {&types.ProxyAccessToken{}, "created_by", "created_by = ?"}, + {&types.Job{}, "triggered_by", "triggered_by = ?"}, + } + + log.Info("Updating user ID in the store") + err := s.transaction(func(tx *gorm.DB) error { + if err := s.txDeferFKConstraints(tx); err != nil { + return err + } + + for _, u := range updates { + if err := tx.Model(u.model).Where(u.where, oldUserID).Update(u.column, newUserID).Error; err != nil { + return fmt.Errorf("update %s: %w", u.column, err) + } + } + + if err := tx.Model(&types.User{}).Where(accountAndIDQueryCondition, accountID, oldUserID).Update("id", newUserID).Error; err != nil { + return fmt.Errorf("update users: %w", err) + } + + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to update user ID in the store: %s", err) + return status.Errorf(status.Internal, "failed to update user ID in store") + } + + log.Info("Restoring FK constraints") + err = s.transaction(func(tx *gorm.DB) error { + if err := s.txRestoreFKConstraints(tx); err != nil { + return fmt.Errorf("restore FK constraints: %w", err) + } + + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to restore FK constraints after user ID update: %s", err) + return status.Errorf(status.Internal, "failed to restore FK constraints after user ID update") + } + + return nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index bafa63580..5a5616abc 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -22,6 +22,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -350,6 +352,35 @@ func TestSqlite_DeleteAccount(t *testing.T) { }, } + account.Services = []*rpservice.Service{ + { + ID: "service_id", + AccountID: account.Id, + Name: "test service", + Domain: "svc.example.com", + Enabled: true, + Targets: []*rpservice.Target{ + { + AccountID: account.Id, + ServiceID: "service_id", + Host: "localhost", + Port: 8080, + Protocol: "http", + Enabled: true, + }, + }, + }, + } + + account.Domains = []*proxydomain.Domain{ + { + ID: "domain_id", + Domain: "custom.example.com", + AccountID: account.Id, + Validated: true, + }, + } + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -411,6 +442,20 @@ func TestSqlite_DeleteAccount(t *testing.T) { require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources") require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount") } + + domains, err := store.ListCustomDomains(context.Background(), account.Id) + require.NoError(t, err, "expecting no error after DeleteAccount when searching for custom domains") + require.Len(t, domains, 0, "expecting no custom domains to be found after DeleteAccount") + + var services []*rpservice.Service + err = store.(*SqlStore).db.Model(&rpservice.Service{}).Find(&services, "account_id = ?", account.Id).Error + require.NoError(t, err, "expecting no error after DeleteAccount when searching for services") + require.Len(t, services, 0, "expecting no services to be found after DeleteAccount") + + var targets []*rpservice.Target + err = store.(*SqlStore).db.Model(&rpservice.Target{}).Find(&targets, "account_id = ?", account.Id).Error + require.NoError(t, err, "expecting no error after DeleteAccount when searching for service targets") + require.Len(t, targets, 0, "expecting no service targets to be found after DeleteAccount") } func Test_GetAccount(t *testing.T) { @@ -2684,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { { name: "should retrieve peers for an existing account ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 4, + expectedCount: 5, }, { name: "should return no peers for a non-existing account ID", @@ -2706,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { name: "should filter peers by partial name", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", nameFilter: "host", - expectedCount: 3, + expectedCount: 4, }, { name: "should filter peers by ip", @@ -2732,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - accountID string - expectedCount int + name string + accountID string + expectedCount int + expectedPeerIDs []string }{ { - name: "should retrieve peers with expiration for an existing account ID", - accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 1, + name: "should retrieve only non-expired peers with expiration enabled", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + expectedPeerIDs: []string{"notexpired01"}, }, { name: "should return no peers with expiration for a non-existing account ID", @@ -2758,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) + for i, peer := range peers { + assert.Equal(t, tt.expectedPeerIDs[i], peer.ID) + } }) } } +func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID) + require.NoError(t, err) + + // Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned") + assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set") + } +} + func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) @@ -2842,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) { name: "should retrieve peers for another valid account ID and user ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", userID: "edafee4e-63fb-11ec-90d6-0242ac120003", - expectedCount: 2, + expectedCount: 3, }, { name: "should return no peers for existing account ID with empty user ID", diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index f2abafceb..81c4b33ae 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -265,6 +266,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, &service.Service{}, &service.Target{}, + &domain.Domain{}, } for i := len(models) - 1; i >= 0; i-- { diff --git a/management/server/store/store.go b/management/server/store/store.go index d00dcde38..0d8b0678a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -121,7 +121,7 @@ type Store interface { GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error @@ -287,6 +287,9 @@ type Store interface { UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) @@ -446,6 +449,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.RemoveDuplicatePeerKeys(ctx, db) }, + func(db *gorm.DB) error { + return migration.CleanupOrphanedResources[rpservice.Service, types.Account](ctx, db, "account_id") + }, + func(db *gorm.DB) error { + return migration.CleanupOrphanedResources[domain.Domain, types.Account](ctx, db, "account_id") + }, } } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 235405861..beee13d96 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -165,6 +165,19 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) } +// GetClusterSupportsCrowdSec mocks base method. +func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. +func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) +} // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -1361,6 +1374,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx) } +// GetClusterRequireSubdomain mocks base method. +func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain. +func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) +} + +// GetClusterSupportsCustomPorts mocks base method. +func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts. +func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr) +} + // GetCustomDomain mocks base method. func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) { m.ctrl.T.Helper() @@ -1438,18 +1479,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou } // GetGroupByName mocks base method. -func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) { +func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID) + ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName) ret0, _ := ret[0].(*types2.Group) ret1, _ := ret[1].(error) return ret0, ret1 } // GetGroupByName indicates an expected call of GetGroupByName. -func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName) } // GetGroupsByIDs mocks base method. @@ -1946,6 +1987,21 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID) } +// GetRoutingPeerNetworks mocks base method. +func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks. +func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID) +} + // GetServiceByDomain mocks base method. func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { m.ctrl.T.Helper() @@ -2333,21 +2389,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) } -// GetRoutingPeerNetworks mocks base method. -func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks. -func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID) -} - // IsPrimaryAccount mocks base method. func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { m.ctrl.T.Helper() diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index 3b1e078eb..518aae7eb 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -4,6 +4,7 @@ import ( "context" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -11,6 +12,7 @@ import ( type AccountManagerMetrics struct { ctx context.Context updateAccountPeersDurationMs metric.Float64Histogram + updateAccountPeersCounter metric.Int64Counter getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram peerMetaUpdateCount metric.Int64Counter @@ -48,6 +50,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + updateAccountPeersCounter, err := meter.Int64Counter("management.account.update.account.peers.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of account peers updates triggered, labeled by resource and operation")) + if err != nil { + return nil, err + } + peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"), metric.WithDescription("Number of updates with new meta data from the peers")) @@ -59,6 +68,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, + updateAccountPeersCounter: updateAccountPeersCounter, networkMapObjectCount: networkMapObjectCount, peerMetaUpdateCount: peerMetaUpdateCount, }, nil @@ -80,6 +90,16 @@ func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { metrics.networkMapObjectCount.Record(metrics.ctx, count) } +// CountUpdateAccountPeersTriggered increments the counter for account peers updates with resource and operation labels. +func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource, operation string) { + metrics.updateAccountPeersCounter.Add(metrics.ctx, 1, + metric.WithAttributes( + attribute.String("resource", resource), + attribute.String("operation", operation), + ), + ) +} + // CountPeerMetUpdate counts the number of peer meta updates func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index c50ed1e51..e48e6d64a 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -183,19 +183,22 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { w := WrapResponseWriter(rw) - h.ServeHTTP(w, r.WithContext(ctx)) + handlerDone := make(chan struct{}) + context.AfterFunc(ctx, func() { + select { + case <-handlerDone: + default: + log.Debugf("HTTP request context canceled mid-flight: %v %v (reqID=%s, after %v, cause: %v)", + r.Method, r.URL.Path, reqID, time.Since(reqStart), context.Cause(ctx)) + } + }) - userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) - if err == nil { - if userAuth.AccountId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId) - } - if userAuth.UserId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId) - } - } + // Hold on to req so auth's in-place ctx update is visible after ServeHTTP. + req := r.WithContext(ctx) + h.ServeHTTP(w, req) + close(handlerDone) + + ctx = req.Context() if w.Status() > 399 { log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index dfcaeee6f..189bd1262 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/types/account.go b/management/server/types/account.go index 269fc7a88..e7c1e2dce 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -8,7 +8,6 @@ import ( "slices" "strconv" "strings" - "sync" "time" "github.com/hashicorp/go-multierror" @@ -18,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" + proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" @@ -26,7 +26,6 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" @@ -101,6 +100,7 @@ type Account struct { DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` Services []*service.Service `gorm:"foreignKey:AccountID;references:id"` + Domains []*proxydomain.Domain `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` @@ -108,16 +108,9 @@ type Account struct { NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` - NetworkMapCache *NetworkMapBuilder `gorm:"-"` - nmapInitOnce *sync.Once `gorm:"-"` - ReverseProxyFreeDomainNonce string } -func (a *Account) InitOnce() { - a.nmapInitOnce = &sync.Once{} -} - // this class is used by gorm only type PrimaryAccountInfo struct { IsDomainPrimaryAccount bool @@ -153,108 +146,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool { o.SignupFormPending == onboarding.SignupFormPending } -// GetRoutesToSync returns the enabled routes for the peer ID and the routes -// from the ACL peers that have distribution groups associated with the peer ID. -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) - peerRoutesMembership := make(LookupMap) - for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups) - filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - routes = append(routes, filteredRoutes...) - } - - return routes -} - -// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership -func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - _, found := peerMemberships[string(r.GetHAUniqueID())] - if !found { - filteredRoutes = append(filteredRoutes, r) - } - } - return filteredRoutes -} - -// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map -func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - for _, groupID := range r.Groups { - _, found := groupListMap[groupID] - if found { - filteredRoutes = append(filteredRoutes, r) - break - } - } - } - return filteredRoutes -} - -// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -// If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - - peer := a.GetPeer(peerID) - if peer == nil { - log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - for _, r := range a.Routes { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) - continue - } - for _, id := range group.Peers { - if id != peerID { - continue - } - - newPeerRoute := r.Copy() - newPeerRoute.Peer = id - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map - takeRoute(newPeerRoute, id) - break - } - } - if r.Peer == peerID { - takeRoute(r.Copy(), peerID) - } - } - - return enabledRoutes, disabledRoutes -} - // GetRoutesByPrefixOrDomains return list of routes by account and route prefix func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { var routes []*route.Route @@ -274,106 +165,6 @@ func (a *Account) GetGroup(groupID string) *Group { return a.Groups[groupID] } -// GetPeerNetworkMap returns the networkmap for the given peer ID. -func (a *Account) GetPeerNetworkMap( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - validatedPeersMap map[string]struct{}, - resourcePolicies map[string][]*Policy, - routers map[string]map[string]*routerTypes.NetworkRouter, - metrics *telemetry.AccountManagerMetrics, - groupIDToUserIDs map[string][]string, -) *NetworkMap { - start := time.Now() - peer := a.Peers[peerID] - if peer == nil { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - if _, ok := validatedPeersMap[peerID]; !ok { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - peerGroups := a.GetPeerGroups(peerID) - - aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs) - // exclude expired peers - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - for _, p := range aclPeers { - expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) - if a.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, p) - continue - } - peersToConnect = append(peersToConnect, p) - } - - routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups) - routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) - var networkResourcesFirewallRules []*RouteFirewallRule - if isRouter { - networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) - } - peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) - - dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones []nbdns.CustomZone - - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) - zones = append(zones, nbdns.CustomZone{ - Domain: peersCustomZone.Domain, - Records: records, - }) - } - - filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) - zones = append(zones, filteredAccountZones...) - - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) - } - - nm := &NetworkMap{ - Peers: peersToConnectIncludingRouters, - Network: a.Network.Copy(), - Routes: slices.Concat(networkResourcesRoutes, routesUpdate), - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), - AuthorizedUsers: authorizedUsers, - EnableSSH: enableSSH, - } - - if metrics != nil { - objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ - "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d", - a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules)) - } - } - - return nm -} - func (a *Account) addNetworksRoutingPeers( networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, @@ -419,39 +210,6 @@ func (a *Account) addNetworksRoutingPeers( return peersToConnect } -func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { - groupList := account.GetPeerGroups(peerID) - - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -// peerIsNameserver returns true if the peer is a nameserver for a nsGroup -func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { - for _, ns := range nsGroup.NameServers { - if peer.IP.Equal(ns.IP.AsSlice()) { - return true - } - } - return false -} - func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { for _, peer := range account.Peers { label, err := GetPeerHostLabel(peer.Name, peerLabels) @@ -798,19 +556,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string { return grps } -func (a *Account) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := a.GetPeerGroups(peerID) - enabled := true - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - func (a *Account) GetPeerGroups(peerID string) LookupMap { groupList := make(LookupMap) for groupID, group := range a.Groups { @@ -911,6 +656,11 @@ func (a *Account) Copy() *Account { services = append(services, svc.Copy()) } + domains := []*proxydomain.Domain{} + for _, domain := range a.Domains { + domains = append(domains, domain.Copy()) + } + return &Account{ Id: a.Id, CreatedBy: a.CreatedBy, @@ -934,8 +684,7 @@ func (a *Account) Copy() *Account { NetworkResources: networkResources, Services: services, Onboarding: a.Onboarding, - NetworkMapCache: a.NetworkMapCache, - nmapInitOnce: a.nmapInitOnce, + Domains: domains, } } @@ -1296,31 +1045,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks { return nil } -// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. -func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) - - enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) - for _, route := range enabledRoutes { - // If no access control groups are specified, accept all traffic. - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := a.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup}) - rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - - return routesFirewallRules -} - func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { var fwRules []*RouteFirewallRule for _, policy := range policies { @@ -1379,50 +1103,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID return distributionGroupPeers } -func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func getDefaultPermit(route *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule - - sources := []string{"0.0.0.0/0"} - if route.Network.Addr().Is6() { - sources = []string{"::/0"} - } - rule := RouteFirewallRule{ - SourceRanges: sources, - Action: string(PolicyTrafficActionAccept), - Destination: route.Network.String(), - Protocol: string(PolicyRuleProtocolALL), - Domains: route.Domains, - IsDynamic: route.IsDynamic(), - RouteID: route.ID, - } - - rules = append(rules, &rule) - - // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally - if route.IsDynamic() { - ruleV6 := rule - ruleV6.SourceRanges = []string{"::/0"} - rules = append(rules, &ruleV6) - } - - return rules -} - // GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups // and returns a list of policies that have rules with destinations matching the specified groups. func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { @@ -1500,65 +1180,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { return resourcePolicies } -// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. -func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { - var isRoutingPeer bool - var routes []*route.Route - allSourcePeers := make(map[string]struct{}, len(a.Peers)) - - for _, resource := range a.NetworkResources { - if !resource.Enabled { - continue - } - - var addSourcePeers bool - - networkRoutingPeers, exists := routers[resource.NetworkID] - if exists { - if router, ok := networkRoutingPeers[peerID]; ok { - isRoutingPeer, addSourcePeers = true, true - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...) - } - } - - addedResourceRoute := false - for _, policy := range resourcePolicies[resource.ID] { - var peers []string - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peers = []string{policy.Rules[0].SourceResource.ID} - } else { - peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) - } - if addSourcePeers { - for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { - allSourcePeers[pID] = struct{}{} - } - } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - // add routes for the resource if the peer is in the distribution group - for peerId, router := range networkRoutingPeers { - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) - } - addedResourceRoute = true - } - if addedResourceRoute { - break - } - } - } - - return isRoutingPeer, routes, allSourcePeers -} - -func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { - var dest []string - for _, peerID := range inputPeers { - if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { - dest = append(dest, peerID) - } - } - return dest -} - func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity for _, groupID := range groups { @@ -1650,22 +1271,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string { return result } -// getNetworkResourcesRoutes convert the network resources list to routes list. -func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route { - resourceAppliedPolicies := resourcePolicies[resource.ID] - - var routes []*route.Route - // distribute the resource routes only if there is policy applied to it - if len(resourceAppliedPolicies) > 0 { - peer := a.GetPeer(peerId) - if peer != nil { - routes = append(routes, resource.ToRoute(peer, router)) - } - } - - return routes -} - func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { routers := make(map[string]map[string]*routerTypes.NetworkRouter) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index af2896216..9b1c9e31d 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "net" - "net/netip" - "slices" "testing" "github.com/miekg/dns" @@ -19,7 +17,6 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" ) @@ -84,6 +81,12 @@ func setupTestAccount() *Account { }, }, Groups: map[string]*Group{ + "groupAll": { + ID: "groupAll", + Name: "All", + Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"}, + Issued: GroupIssuedAPI, + }, "group1": { ID: "group1", Peers: []string{"peer11", "peer12"}, @@ -445,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { require.Len(t, result, 0) } -const ( - accID = "accountID" - network1ID = "network1ID" - group1ID = "group1" - accNetResourcePeer1ID = "peer1" - accNetResourcePeer2ID = "peer2" - accNetResourceRouter1ID = "router1" - accNetResource1ID = "resource1ID" - accNetResourceRestrictPostureCheckID = "restrictPostureCheck" - accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck" - accNetResourceLockedPostureCheckID = "lockedPostureCheck" - accNetResourceLinuxPostureCheckID = "linuxPostureCheck" -) - -var ( - accNetResourcePeer1IP = net.IP{192, 168, 1, 1} - accNetResourcePeer2IP = net.IP{192, 168, 1, 2} - accNetResourceRouter1IP = net.IP{192, 168, 1, 3} - accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}} -) - -func getBasicAccountsWithResource() *Account { - return &Account{ - Id: accID, - Peers: map[string]*nbpeer.Peer{ - accNetResourcePeer1ID: { - ID: accNetResourcePeer1ID, - AccountID: accID, - Key: "peer1Key", - IP: accNetResourcePeer1IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - WtVersion: "0.35.1", - KernelVersion: "4.4.0", - }, - }, - accNetResourcePeer2ID: { - ID: accNetResourcePeer2ID, - AccountID: accID, - Key: "peer2Key", - IP: accNetResourcePeer2IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "windows", - WtVersion: "0.34.1", - KernelVersion: "4.4.0", - }, - }, - accNetResourceRouter1ID: { - ID: accNetResourceRouter1ID, - AccountID: accID, - Key: "router1Key", - IP: accNetResourceRouter1IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - WtVersion: "0.35.1", - KernelVersion: "4.4.0", - }, - }, - }, - Groups: map[string]*Group{ - group1ID: { - ID: group1ID, - Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID}, - }, - }, - Networks: []*networkTypes.Network{ - { - ID: network1ID, - AccountID: accID, - Name: "network1", - }, - }, - NetworkRouters: []*routerTypes.NetworkRouter{ - { - ID: accNetResourceRouter1ID, - NetworkID: network1ID, - AccountID: accID, - Peer: accNetResourceRouter1ID, - PeerGroups: []string{}, - Masquerade: false, - Metric: 100, - Enabled: true, - }, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - { - ID: accNetResource1ID, - AccountID: accID, - NetworkID: network1ID, - Address: "10.10.10.0/24", - Prefix: netip.MustParsePrefix("10.10.10.0/24"), - Type: resourceTypes.NetworkResourceType("subnet"), - Enabled: true, - }, - }, - Policies: []*Policy{ - { - ID: "policy1ID", - AccountID: accID, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "rule1ID", - Enabled: true, - Sources: []string{group1ID}, - DestinationResource: Resource{ - ID: accNetResource1ID, - Type: "Host", - }, - Protocol: PolicyRuleProtocolTCP, - Ports: []string{"80"}, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: nil, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: accNetResourceRestrictPostureCheckID, - Name: accNetResourceRestrictPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.35.0", - }, - }, - }, - { - ID: accNetResourceRelaxedPostureCheckID, - Name: accNetResourceRelaxedPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.0.1", - }, - }, - }, - { - ID: accNetResourceLockedPostureCheckID, - Name: accNetResourceLockedPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "7.7.7", - }, - }, - }, - { - ID: accNetResourceLinuxPostureCheckID, - Name: accNetResourceLinuxPostureCheckID, - Checks: posture.ChecksDefinition{ - OSVersionCheck: &posture.OSVersionCheck{ - Linux: &posture.MinKernelVersionCheck{ - MinKernelVersion: "0.0.0"}, - }, - }, - }, - }, - } -} - -func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // all peers should match the policy - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should allow peer1 to match the policy - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should not match any peer - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 0, "expected rules count don't match") -} - -func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should allow peer1 to match the policy - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} - - // should allow peer1 and peer2 to match the policy - newPolicy := &Policy{ - ID: "policy2ID", - AccountID: accID, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "policy2ID", - Enabled: true, - Sources: []string{group1ID}, - DestinationResource: Resource{ - ID: accNetResource1ID, - Type: "Host", - }, - Protocol: PolicyRuleProtocolTCP, - Ports: []string{"22"}, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID}, - } - - account.Policies = append(account.Policies, newPolicy) - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 2, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } - - assert.Equal(t, uint16(22), rules[1].Port, "should have port 22") - assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp") - if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String()) - } - if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // two posture checks should match only the peers that match both checks - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { - account := getBasicAccountsWithResource() - - account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}} - account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ - ID: "router2Id", - NetworkID: network1ID, - AccountID: accID, - Peer: "router2Id", - }) - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") -} - func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { tests := []struct { name string diff --git a/management/server/types/holder.go b/management/server/types/holder.go deleted file mode 100644 index de8ac8110..000000000 --- a/management/server/types/holder.go +++ /dev/null @@ -1,47 +0,0 @@ -package types - -import ( - "context" - "sync" -) - -type Holder struct { - mu sync.RWMutex - accounts map[string]*Account -} - -func NewHolder() *Holder { - return &Holder{ - accounts: make(map[string]*Account), - } -} - -func (h *Holder) GetAccount(id string) *Account { - h.mu.RLock() - defer h.mu.RUnlock() - return h.accounts[id] -} - -func (h *Holder) AddAccount(account *Account) { - h.mu.Lock() - defer h.mu.Unlock() - a := h.accounts[account.Id] - if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() { - return - } - h.accounts[account.Id] = account -} - -func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { - h.mu.Lock() - defer h.mu.Unlock() - if acc, ok := h.accounts[id]; ok { - return acc, nil - } - account, err := accGetter(ctx, id) - if err != nil { - return nil, err - } - h.accounts[id] = account - return account, nil -} diff --git a/management/server/types/identity_provider.go b/management/server/types/identity_provider.go index c4498e4d4..0c1f9509c 100644 --- a/management/server/types/identity_provider.go +++ b/management/server/types/identity_provider.go @@ -39,6 +39,8 @@ const ( IdentityProviderTypeAuthentik IdentityProviderType = "authentik" // IdentityProviderTypeKeycloak is the Keycloak identity provider IdentityProviderTypeKeycloak IdentityProviderType = "keycloak" + // IdentityProviderTypeADFS is the Microsoft AD FS identity provider + IdentityProviderTypeADFS IdentityProviderType = "adfs" ) // IdentityProvider represents an identity provider configuration @@ -112,7 +114,8 @@ func (t IdentityProviderType) IsValid() bool { switch t { case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra, IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID, - IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak: + IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak, + IdentityProviderTypeADFS: return true } return false diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go deleted file mode 100644 index 68c988a93..000000000 --- a/management/server/types/networkmap.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "context" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" -) - -func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { - if a.NetworkMapCache != nil { - return - } - a.nmapInitOnce.Do(func() { - a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) - }) -} - -func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { - a.initNetworkMapBuilder(validatedPeers) -} - -func (a *Account) GetPeerNetworkMapExp( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - validatedPeers map[string]struct{}, - metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - a.initNetworkMapBuilder(validatedPeers) - return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics) -} - -func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { - if a.NetworkMapCache == nil { - return nil - } - return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId) -} - -func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) { - if a.NetworkMapCache == nil { - return - } - a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...) -} - -func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { - if a.NetworkMapCache == nil { - return nil - } - return a.NetworkMapCache.OnPeerDeleted(a, peerId) -} - -func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { - if a.NetworkMapCache == nil { - return - } - a.NetworkMapCache.UpdatePeer(peer) -} - -func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { - a.initNetworkMapBuilder(validatedPeers) -} diff --git a/management/server/types/networkmap_benchmark_test.go b/management/server/types/networkmap_benchmark_test.go new file mode 100644 index 000000000..38272e7b0 --- /dev/null +++ b/management/server/types/networkmap_benchmark_test.go @@ -0,0 +1,217 @@ +package types_test + +import ( + "context" + "fmt" + "os" + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/types" +) + +type benchmarkScale struct { + name string + peers int + groups int +} + +var defaultScales = []benchmarkScale{ + {"100peers_5groups", 100, 5}, + {"500peers_20groups", 500, 20}, + {"1000peers_50groups", 1000, 50}, + {"5000peers_100groups", 5000, 100}, + {"10000peers_200groups", 10000, 200}, + {"20000peers_200groups", 20000, 200}, + {"30000peers_300groups", 30000, 300}, +} + +func skipCIBenchmark(b *testing.B) { + if os.Getenv("CI") == "true" { + b.Skip("Skipping benchmark in CI") + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Single Peer Network Map Generation +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_Components benchmarks the components-based approach for a single peer. +func BenchmarkNetworkMapGeneration_Components(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// All Peers (UpdateAccountPeers hot path) +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps for ALL peers. +func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) { + skipCIBenchmark(b) + scales := []benchmarkScale{ + {"100peers_5groups", 100, 5}, + {"500peers_20groups", 500, 20}, + {"1000peers_50groups", 1000, 50}, + {"5000peers_100groups", 5000, 100}, + } + + for _, scale := range scales { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + + peerIDs := make([]string, 0, len(account.Peers)) + for peerID := range account.Peers { + peerIDs = append(peerIDs, peerID) + } + + b.Run("components/"+scale.name, func(b *testing.B) { + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Sub-operations +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks components extraction. +func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculation from pre-built components. +func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + components := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = types.CalculateNetworkMapFromComponents(ctx, components) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_PrecomputeMaps benchmarks precomputed map costs. +func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run("ResourcePoliciesMap/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetResourcePoliciesMap() + } + }) + b.Run("ResourceRoutersMap/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetResourceRoutersMap() + } + }) + b.Run("ActiveGroupUsers/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetActiveGroupUsers() + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Scaling Analysis +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on performance. +func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) { + skipCIBenchmark(b) + groupCounts := []int{1, 5, 20, 50, 100, 200, 500} + for _, numGroups := range groupCounts { + b.Run(fmt.Sprintf("components_%dgroups", numGroups), func(b *testing.B) { + account, validatedPeers := scalableTestAccount(1000, numGroups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on performance. +func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) { + skipCIBenchmark(b) + peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000} + for _, numPeers := range peerCounts { + numGroups := numPeers / 20 + if numGroups < 1 { + numGroups = 1 + } + b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) { + account, validatedPeers := scalableTestAccount(numPeers, numGroups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go deleted file mode 100644 index c5844cca0..000000000 --- a/management/server/types/networkmap_comparison_test.go +++ /dev/null @@ -1,592 +0,0 @@ -package types - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/netip" - "os" - "path/filepath" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/require" - - nbdns "github.com/netbirdio/netbird/dns" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/route" -) - -func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { - account := createTestAccount() - ctx := context.Background() - - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid == offlinePeerID { - continue - } - validatedPeersMap[pid] = struct{}{} - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - legacyNetworkMap := account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - if components == nil { - t.Fatal("GetPeerNetworkMapComponents returned nil") - } - - newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) - - if newNetworkMap == nil { - t.Fatal("CalculateNetworkMapFromComponents returned nil") - } - - compareNetworkMaps(t, legacyNetworkMap, newNetworkMap) -} - -func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { - account := createTestAccount() - ctx := context.Background() - - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid == offlinePeerID { - continue - } - validatedPeersMap[pid] = struct{}{} - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - legacyNetworkMap := account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") - - newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) - require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") - - normalizeAndSortNetworkMap(legacyNetworkMap) - normalizeAndSortNetworkMap(newNetworkMap) - - componentsJSON, err := json.MarshalIndent(components, "", " ") - require.NoError(t, err, "error marshaling components to JSON") - - legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - goldenDir := filepath.Join("testdata", "comparison") - err = os.MkdirAll(goldenDir, 0755) - require.NoError(t, err) - - legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json") - err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644) - require.NoError(t, err, "error writing legacy golden file") - - newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json") - err = os.WriteFile(newGoldenPath, newJSON, 0644) - require.NoError(t, err, "error writing components golden file") - - componentsPath := filepath.Join(goldenDir, "components.json") - err = os.WriteFile(componentsPath, componentsJSON, 0644) - require.NoError(t, err, "error writing components golden file") - - require.JSONEq(t, string(legacyJSON), string(newJSON), - "NetworkMaps from legacy and components approaches do not match.\n"+ - "Legacy JSON saved to: %s\n"+ - "Components JSON saved to: %s", - legacyGoldenPath, newGoldenPath) - - t.Logf("✅ NetworkMaps are identical") - t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath) - t.Logf(" Components NetworkMap: %s", newGoldenPath) -} - -func normalizeAndSortNetworkMap(nm *NetworkMap) { - if nm == nil { - return - } - - sort.Slice(nm.Peers, func(i, j int) bool { - return nm.Peers[i].ID < nm.Peers[j].ID - }) - - sort.Slice(nm.OfflinePeers, func(i, j int) bool { - return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID - }) - - sort.Slice(nm.Routes, func(i, j int) bool { - return string(nm.Routes[i].ID) < string(nm.Routes[j].ID) - }) - - sort.Slice(nm.FirewallRules, func(i, j int) bool { - if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP { - return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP - } - if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction { - return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction - } - if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol { - return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol - } - if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port { - return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port - } - return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID - }) - - for i := range nm.RoutesFirewallRules { - sort.Strings(nm.RoutesFirewallRules[i].SourceRanges) - } - - sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool { - if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination { - return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination - } - - minLen := len(nm.RoutesFirewallRules[i].SourceRanges) - if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen { - minLen = len(nm.RoutesFirewallRules[j].SourceRanges) - } - for k := 0; k < minLen; k++ { - if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] { - return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k] - } - } - if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) { - return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges) - } - - if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) { - return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID) - } - - if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID { - return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID - } - - if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port { - return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port - } - - return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol - }) - - if nm.DNSConfig.CustomZones != nil { - for i := range nm.DNSConfig.CustomZones { - sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool { - return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name - }) - } - } - - if len(nm.DNSConfig.NameServerGroups) != 0 { - sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool { - return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name - }) - } -} - -func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) { - t.Helper() - - if legacy.Network.Serial != current.Network.Serial { - t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial) - } - - if len(legacy.Peers) != len(current.Peers) { - t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers)) - } - - legacyPeerIDs := make(map[string]bool) - for _, p := range legacy.Peers { - legacyPeerIDs[p.ID] = true - } - - for _, p := range current.Peers { - if !legacyPeerIDs[p.ID] { - t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID) - } - } - - if len(legacy.OfflinePeers) != len(current.OfflinePeers) { - t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers)) - } - - if len(legacy.FirewallRules) != len(current.FirewallRules) { - t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules)) - } - - if len(legacy.Routes) != len(current.Routes) { - t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes)) - } - - if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) { - t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules)) - } - - if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable { - t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable) - } -} - -const ( - numPeers = 100 - devGroupID = "group-dev" - opsGroupID = "group-ops" - allGroupID = "group-all" - routeID = route.ID("route-main") - routeHA1ID = route.ID("route-ha-1") - routeHA2ID = route.ID("route-ha-2") - policyIDDevOps = "policy-dev-ops" - policyIDAll = "policy-all" - policyIDPosture = "policy-posture" - policyIDDrop = "policy-drop" - postureCheckID = "posture-check-ver" - networkResourceID = "res-database" - networkID = "net-database" - networkRouterID = "router-database" - nameserverGroupID = "ns-group-main" - testingPeerID = "peer-60" - expiredPeerID = "peer-98" - offlinePeerID = "peer-99" - routingPeerID = "peer-95" - testAccountID = "account-comparison-test" -) - -func createTestAccount() *Account { - peers := make(map[string]*nbpeer.Peer) - devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} - - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, 64, 0, byte(i + 1)} - wtVersion := "0.25.0" - if i%2 == 0 { - wtVersion = "0.40.0" - } - - p := &nbpeer.Peer{ - ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, - } - - if peerID == expiredPeerID { - p.LoginExpirationEnabled = true - pastTimestamp := time.Now().Add(-2 * time.Hour) - p.LastLogin = &pastTimestamp - } - - peers[peerID] = p - allGroupPeers = append(allGroupPeers, peerID) - if i < numPeers/2 { - devGroupPeers = append(devGroupPeers, peerID) - } else { - opsGroupPeers = append(opsGroupPeers, peerID) - } - } - - groups := map[string]*Group{ - allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, - devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, - opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, - } - - policies := []*Policy{ - { - ID: policyIDAll, Name: "Default-Allow", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{allGroupID}, Destinations: []string{allGroupID}, - }}, - }, - { - ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolTCP, Bidirectional: false, - PortRanges: []RulePortRange{{Start: 8080, End: 8090}}, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop, - Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, - SourcePostureChecks: []string{postureCheckID}, - Rules: []*PolicyRule{{ - ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID}, - }}, - }, - } - - routes := map[route.ID]*route.Route{ - routeID: { - ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), - Peer: peers["peer-75"].Key, - PeerID: "peer-75", - Description: "Route to internal resource", Enabled: true, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - }, - routeHA1ID: { - ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-80"].Key, - PeerID: "peer-80", - Description: "HA Route 1", Enabled: true, Metric: 1000, - PeerGroups: []string{allGroupID}, - Groups: []string{allGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - routeHA2ID: { - ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-90"].Key, - PeerID: "peer-90", - Description: "HA Route 2", Enabled: true, Metric: 900, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - } - - account := &Account{ - Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, - Network: &Network{ - Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, - }, - DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, - NameServerGroups: map[string]*nbdns.NameServerGroup{ - nameserverGroupID: { - ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, - NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, - }, - }, - PostureChecks: []*posture.Checks{ - {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, - }}, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, - }, - Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, - NetworkRouters: []*routerTypes.NetworkRouter{ - {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, - }, - Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, - } - - for _, p := range account.Policies { - p.AccountID = account.Id - } - for _, r := range account.Routes { - r.AccountID = account.Id - } - - return account -} - -func BenchmarkLegacyNetworkMap(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - } -} - -func BenchmarkComponentsNetworkMap(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - _ = CalculateNetworkMapFromComponents(ctx, components) - } -} - -func BenchmarkComponentsCreation(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - } -} - -func BenchmarkCalculationFromComponents(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = CalculateNetworkMapFromComponents(ctx, components) - } -} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index 23d84a994..6f84c8d30 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -19,8 +19,6 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" ) -const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED" - type NetworkMapComponents struct { PeerID string diff --git a/management/server/types/networkmap_components_correctness_test.go b/management/server/types/networkmap_components_correctness_test.go new file mode 100644 index 000000000..5cd41ff10 --- /dev/null +++ b/management/server/types/networkmap_components_correctness_test.go @@ -0,0 +1,1192 @@ +package types_test + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// scalableTestAccountWithoutDefaultPolicy creates an account without the blanket "Allow All" policy. +// Use this for tests that need to verify feature-specific connectivity in isolation. +func scalableTestAccountWithoutDefaultPolicy(numPeers, numGroups int) (*types.Account, map[string]struct{}) { + return buildScalableTestAccount(numPeers, numGroups, false) +} + +// scalableTestAccount creates a realistic account with a blanket "Allow All" policy +// plus per-group policies, routes, network resources, posture checks, and DNS settings. +func scalableTestAccount(numPeers, numGroups int) (*types.Account, map[string]struct{}) { + return buildScalableTestAccount(numPeers, numGroups, true) +} + +// buildScalableTestAccount is the core builder. When withDefaultPolicy is true it adds +// a blanket group-all <-> group-all allow rule; when false the only policies are the +// per-group ones, so tests can verify feature-specific connectivity in isolation. +func buildScalableTestAccount(numPeers, numGroups int, withDefaultPolicy bool) (*types.Account, map[string]struct{}) { + peers := make(map[string]*nbpeer.Peer, numPeers) + allGroupPeers := make([]string, 0, numPeers) + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, + IP: ip, + Key: fmt.Sprintf("key-%s", peerID), + DNSLabel: fmt.Sprintf("peer%d", i), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if i == numPeers-2 { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + } + + groups := make(map[string]*types.Group, numGroups+1) + groups["group-all"] = &types.Group{ID: "group-all", Name: "All", Peers: allGroupPeers} + + peersPerGroup := numPeers / numGroups + if peersPerGroup < 1 { + peersPerGroup = 1 + } + + for g := range numGroups { + groupID := fmt.Sprintf("group-%d", g) + groupPeers := make([]string, 0, peersPerGroup) + start := g * peersPerGroup + end := start + peersPerGroup + if end > numPeers { + end = numPeers + } + for i := start; i < end; i++ { + groupPeers = append(groupPeers, fmt.Sprintf("peer-%d", i)) + } + groups[groupID] = &types.Group{ID: groupID, Name: fmt.Sprintf("Group %d", g), Peers: groupPeers} + } + + policies := make([]*types.Policy, 0, numGroups+2) + if withDefaultPolicy { + policies = append(policies, &types.Policy{ + ID: "policy-all", Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-all", Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-all"}, Destinations: []string{"group-all"}, + }}, + }) + } + + for g := range numGroups { + groupID := fmt.Sprintf("group-%d", g) + dstGroup := fmt.Sprintf("group-%d", (g+1)%numGroups) + policies = append(policies, &types.Policy{ + ID: fmt.Sprintf("policy-%d", g), Name: fmt.Sprintf("Policy %d", g), Enabled: true, + Rules: []*types.PolicyRule{{ + ID: fmt.Sprintf("rule-%d", g), Name: fmt.Sprintf("Rule %d", g), Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"8080"}, + Sources: []string{groupID}, Destinations: []string{dstGroup}, + }}, + }) + } + + if numGroups >= 2 { + policies = append(policies, &types.Policy{ + ID: "policy-drop", Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-drop", Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + } + + numRoutes := numGroups + if numRoutes > 20 { + numRoutes = 20 + } + routes := make(map[route.ID]*route.Route, numRoutes) + for r := range numRoutes { + routeID := route.ID(fmt.Sprintf("route-%d", r)) + peerIdx := (numPeers / 2) + r + if peerIdx >= numPeers { + peerIdx = numPeers - 1 + } + routePeerID := fmt.Sprintf("peer-%d", peerIdx) + groupID := fmt.Sprintf("group-%d", r%numGroups) + routes[routeID] = &route.Route{ + ID: routeID, + Network: netip.MustParsePrefix(fmt.Sprintf("10.%d.0.0/16", r)), + Peer: peers[routePeerID].Key, + PeerID: routePeerID, + Description: fmt.Sprintf("Route %d", r), + Enabled: true, + PeerGroups: []string{groupID}, + Groups: []string{"group-all"}, + AccessControlGroups: []string{groupID}, + AccountID: "test-account", + } + } + + numResources := numGroups / 2 + if numResources < 1 { + numResources = 1 + } + if numResources > 50 { + numResources = 50 + } + + networkResources := make([]*resourceTypes.NetworkResource, 0, numResources) + networksList := make([]*networkTypes.Network, 0, numResources) + networkRouters := make([]*routerTypes.NetworkRouter, 0, numResources) + + routingPeerStart := numPeers * 3 / 4 + for nr := range numResources { + netID := fmt.Sprintf("net-%d", nr) + resID := fmt.Sprintf("res-%d", nr) + routerPeerIdx := routingPeerStart + nr + if routerPeerIdx >= numPeers { + routerPeerIdx = numPeers - 1 + } + routerPeerID := fmt.Sprintf("peer-%d", routerPeerIdx) + + networksList = append(networksList, &networkTypes.Network{ID: netID, Name: fmt.Sprintf("Network %d", nr), AccountID: "test-account"}) + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: fmt.Sprintf("svc-%d.netbird.cloud", nr), + }) + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", nr), NetworkID: netID, Peer: routerPeerID, + Enabled: true, AccountID: "test-account", + }) + + policies = append(policies, &types.Policy{ + ID: fmt.Sprintf("policy-res-%d", nr), Name: fmt.Sprintf("Resource Policy %d", nr), Enabled: true, + SourcePostureChecks: []string{"posture-check-ver"}, + Rules: []*types.PolicyRule{{ + ID: fmt.Sprintf("rule-res-%d", nr), Name: fmt.Sprintf("Allow Resource %d", nr), Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{fmt.Sprintf("group-%d", nr%numGroups)}, + DestinationResource: types.Resource{ID: resID}, + }}, + }) + } + + account := &types.Account{ + Id: "test-account", + Peers: peers, + Groups: groups, + Policies: policies, + Routes: routes, + Users: map[string]*types.User{ + "user-admin": {Id: "user-admin", Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: "test-account"}, + }, + Network: &types.Network{ + Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "ns-group-main": { + ID: "ns-group-main", Name: "Main NS", Enabled: true, Groups: []string{"group-all"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: "posture-check-ver", Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: networkResources, + Networks: networksList, + NetworkRouters: networkRouters, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + validatedPeers := make(map[string]struct{}, numPeers) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if i != numPeers-1 { + validatedPeers[peerID] = struct{}{} + } + } + + return account, validatedPeers +} + +// componentsNetworkMap is a convenience wrapper for GetPeerNetworkMapFromComponents. +func componentsNetworkMap(account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap { + return account.GetPeerNetworkMapFromComponents( + context.Background(), peerID, nbdns.CustomZone{}, nil, + validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), + nil, account.GetActiveGroupUsers(), + ) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 1. PEER VISIBILITY & GROUPS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_PeerVisibility(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Equal(t, len(validatedPeers)-1-len(nm.OfflinePeers), len(nm.Peers), "peer should see all other validated non-expired peers") +} + +func TestComponents_PeerDoesNotSeeItself(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-0", p.ID, "peer should not see itself") + } +} + +func TestComponents_IntraGroupConnectivity(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "peer-0 should see peer-5 from same group") +} + +func TestComponents_CrossGroupConnectivity(t *testing.T) { + // Without default policy, only per-group policies provide connectivity + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-10"], "peer-0 should see peer-10 from cross-group policy") +} + +func TestComponents_BidirectionalPolicy(t *testing.T) { + // Without default policy so bidirectional visibility comes only from per-group policies + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5) + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + nm20 := componentsNetworkMap(account, "peer-20", validatedPeers) + require.NotNil(t, nm0) + require.NotNil(t, nm20) + + peer0SeesPeer20 := false + for _, p := range nm0.Peers { + if p.ID == "peer-20" { + peer0SeesPeer20 = true + } + } + peer20SeesPeer0 := false + for _, p := range nm20.Peers { + if p.ID == "peer-0" { + peer20SeesPeer0 = true + } + } + assert.True(t, peer0SeesPeer20, "peer-0 should see peer-20 via bidirectional policy") + assert.True(t, peer20SeesPeer0, "peer-20 should see peer-0 via bidirectional policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 2. PEER EXPIRATION & ACCOUNT SETTINGS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_ExpiredPeerInOfflineList(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + offlineIDs := make(map[string]bool, len(nm.OfflinePeers)) + for _, p := range nm.OfflinePeers { + offlineIDs[p.ID] = true + } + assert.True(t, offlineIDs["peer-98"], "expired peer should be in OfflinePeers") + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-98", p.ID, "expired peer should not be in active Peers") + } +} + +func TestComponents_ExpirationDisabledSetting(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + account.Settings.PeerLoginExpirationEnabled = false + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-98"], "with expiration disabled, peer-98 should be in active Peers") +} + +func TestComponents_LoginExpiration_PeerLevel(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.Settings.PeerLoginExpirationEnabled = true + account.Settings.PeerLoginExpiration = 1 * time.Hour + + pastLogin := time.Now().Add(-2 * time.Hour) + account.Peers["peer-5"].LastLogin = &pastLogin + account.Peers["peer-5"].LoginExpirationEnabled = true + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + offlineIDs := make(map[string]bool, len(nm.OfflinePeers)) + for _, p := range nm.OfflinePeers { + offlineIDs[p.ID] = true + } + assert.True(t, offlineIDs["peer-5"], "login-expired peer should be in OfflinePeers") + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-5", p.ID, "login-expired peer should not be in active Peers") + } +} + +func TestComponents_NetworkSerial(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + account.Network.Serial = 42 + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Equal(t, uint64(42), nm.Network.Serial, "network serial should match") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 3. NON-VALIDATED PEERS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_NonValidatedPeerExcluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in Peers") + } + for _, p := range nm.OfflinePeers { + assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in OfflinePeers") + } +} + +func TestComponents_NonValidatedTargetPeerGetsEmptyMap(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-99", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) +} + +func TestComponents_NonExistentPeerGetsEmptyMap(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-does-not-exist", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 4. POLICIES & FIREWALL RULES +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_FirewallRulesGenerated(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.FirewallRules, "should have firewall rules from policies") +} + +func TestComponents_DropPolicyGeneratesDropRules(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasDropRule := false + for _, rule := range nm.FirewallRules { + if rule.Action == string(types.PolicyTrafficActionDrop) { + hasDropRule = true + break + } + } + assert.True(t, hasDropRule, "should have at least one drop firewall rule") +} + +func TestComponents_DisabledPolicyIgnored(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + for _, p := range account.Policies { + p.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers, "disabled policies should yield no peers") + assert.Empty(t, nm.FirewallRules, "disabled policies should yield no firewall rules") +} + +func TestComponents_PortPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has8080, has5432 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "8080" { + has8080 = true + } + if rule.Port == "5432" { + has5432 = true + } + } + assert.True(t, has8080, "should have firewall rule for port 8080") + assert.True(t, has5432, "should have firewall rule for port 5432 (drop policy)") +} + +func TestComponents_PortRangePolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + account.Peers["peer-0"].Meta.WtVersion = "0.50.0" + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-port-range", Name: "Port Range", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-port-range", Name: "Port Range Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + PortRanges: []types.RulePortRange{{Start: 8000, End: 9000}}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasPortRange := false + for _, rule := range nm.FirewallRules { + if rule.PortRange.Start == 8000 && rule.PortRange.End == 9000 { + hasPortRange = true + break + } + } + assert.True(t, hasPortRange, "should have firewall rule with port range 8000-9000") +} + +func TestComponents_FirewallRuleDirection(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasIn, hasOut := false, false + for _, rule := range nm.FirewallRules { + if rule.Direction == types.FirewallRuleDirectionIN { + hasIn = true + } + if rule.Direction == types.FirewallRuleDirectionOUT { + hasOut = true + } + } + assert.True(t, hasIn, "should have inbound firewall rules") + assert.True(t, hasOut, "should have outbound firewall rules") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 5. ROUTES +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_RoutesIncluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Routes, "should have routes") +} + +func TestComponents_DisabledRouteExcluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + for _, r := range account.Routes { + r.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, r := range nm.Routes { + assert.True(t, r.Enabled, "only enabled routes should appear") + } +} + +func TestComponents_RoutesFirewallRulesForACG(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.RoutesFirewallRules, "should have route firewall rules for access-controlled routes") +} + +func TestComponents_HARouteDeduplication(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + + haNetwork := netip.MustParsePrefix("172.16.0.0/16") + account.Routes["route-ha-1"] = &route.Route{ + ID: "route-ha-1", Network: haNetwork, PeerID: "peer-10", + Peer: account.Peers["peer-10"].Key, Enabled: true, Metric: 100, + Groups: []string{"group-all"}, PeerGroups: []string{"group-0"}, AccountID: "test-account", + } + account.Routes["route-ha-2"] = &route.Route{ + ID: "route-ha-2", Network: haNetwork, PeerID: "peer-20", + Peer: account.Peers["peer-20"].Key, Enabled: true, Metric: 200, + Groups: []string{"group-all"}, PeerGroups: []string{"group-1"}, AccountID: "test-account", + } + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + haRoutes := 0 + for _, r := range nm.Routes { + if r.Network == haNetwork { + haRoutes++ + } + } + // Components deduplicates HA routes with the same HA unique ID, returning one entry per HA group + assert.Equal(t, 1, haRoutes, "HA routes with same network should be deduplicated into one entry") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 6. NETWORK RESOURCES & ROUTERS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + + var routerPeerID string + for _, nr := range account.NetworkRouters { + routerPeerID = nr.Peer + break + } + require.NotEmpty(t, routerPeerID) + + nm := componentsNetworkMap(account, routerPeerID, validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Peers, "router peer should see source peers") +} + +func TestComponents_NetworkResourceRoutes_SourcePeerSeesRouterPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + + var routerPeerID string + for _, nr := range account.NetworkRouters { + routerPeerID = nr.Peer + break + } + require.NotEmpty(t, routerPeerID) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs[routerPeerID], "source peer should see router peer for network resource") +} + +func TestComponents_DisabledNetworkResourceIgnored(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + for _, nr := range account.NetworkResources { + nr.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotNil(t, nm.Network) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 7. POSTURE CHECKS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_PostureCheckFiltering_PassingPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Routes, "passing peer should have routes including resource routes") +} + +func TestComponents_PostureCheckFiltering_FailingPeer(t *testing.T) { + // peer-0 has version 0.40.0 (passes posture check >= 0.26.0) + // peer-1 has version 0.25.0 (fails posture check >= 0.26.0) + // Resource policies require posture-check-ver, so the failing peer + // should not see the router peer for those resources. + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5) + + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + nm1 := componentsNetworkMap(account, "peer-1", validatedPeers) + require.NotNil(t, nm0) + require.NotNil(t, nm1) + + // The passing peer should have more peers visible (including resource router peers) + // than the failing peer, because the failing peer is excluded from resource policies. + assert.Greater(t, len(nm0.Peers), len(nm1.Peers), + "passing peer (0.40.0) should see more peers than failing peer (0.25.0) due to posture-gated resource policies") +} + +func TestComponents_MultiplePostureChecks(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(50, 2) + + // Keep only the posture-gated policy — remove per-group policies so connectivity is isolated + account.Policies = []*types.Policy{} + + // Set kernel version on peers so the OS posture check can evaluate + for _, p := range account.Peers { + p.Meta.KernelVersion = "5.15.0" + } + + account.PostureChecks = append(account.PostureChecks, &posture.Checks{ + ID: "posture-check-os", Name: "Check OS", + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "0.0.1"}}, + }, + }) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-posture", Name: "Multi Posture", Enabled: true, AccountID: "test-account", + SourcePostureChecks: []string{"posture-check-ver", "posture-check-os"}, + Rules: []*types.PolicyRule{{ + ID: "rule-multi-posture", Name: "Multi Check Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + + // peer-0 (0.40.0, kernel 5.15.0) passes both checks, should see group-1 peers + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + assert.NotEmpty(t, nm0.Peers, "peer passing both posture checks should see destination peers") + + // peer-1 (0.25.0, kernel 5.15.0) fails version check, should NOT see group-1 peers + nm1 := componentsNetworkMap(account, "peer-1", validatedPeers) + require.NotNil(t, nm1) + assert.Empty(t, nm1.Peers, + "peer failing posture check should see no peers when posture-gated policy is the only connectivity") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 8. DNS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_DNSConfigEnabled(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.DNSConfig.ServiceEnable, "DNS should be enabled") + assert.NotEmpty(t, nm.DNSConfig.NameServerGroups, "should have nameserver groups") +} + +func TestComponents_DNSDisabledByManagementGroup(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + account.DNSSettings.DisabledManagementGroups = []string{"group-all"} + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.False(t, nm.DNSConfig.ServiceEnable, "DNS should be disabled for peer in disabled group") +} + +func TestComponents_DNSNameServerGroupDistribution(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.NameServerGroups["ns-group-0"] = &nbdns.NameServerGroup{ + ID: "ns-group-0", Name: "Group 0 NS", Enabled: true, Groups: []string{"group-0"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53}}, + } + + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + hasGroup0NS := false + for _, ns := range nm0.DNSConfig.NameServerGroups { + if ns.ID == "ns-group-0" { + hasGroup0NS = true + } + } + assert.True(t, hasGroup0NS, "peer-0 in group-0 should receive ns-group-0") + + nm10 := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm10) + hasGroup0NSForPeer10 := false + for _, ns := range nm10.DNSConfig.NameServerGroups { + if ns.ID == "ns-group-0" { + hasGroup0NSForPeer10 = true + } + } + assert.False(t, hasGroup0NSForPeer10, "peer-10 in group-1 should NOT receive ns-group-0") +} + +func TestComponents_DNSCustomZone(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + customZone := nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer0.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-0"].IP.String()}, + {Name: "peer1.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-1"].IP.String()}, + }, + } + + nm := account.GetPeerNetworkMapFromComponents( + context.Background(), "peer-0", customZone, nil, + validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), + nil, account.GetActiveGroupUsers(), + ) + require.NotNil(t, nm) + assert.True(t, nm.DNSConfig.ServiceEnable) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 9. SSH +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_SSHPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}} + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "Allow SSH", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Bidirectional: false, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + AuthorizedGroups: map[string][]string{"ssh-users": {"root"}}, + }}, + }) + + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be enabled for destination peer of SSH policy") +} + +func TestComponents_SSHNotEnabledWithoutPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.False(t, nm.EnableSSH, "SSH should not be enabled without SSH policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 10. CROSS-PEER CONSISTENCY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_AllPeersGetValidMaps verifies that every validated peer gets a +// non-nil map with a consistent network serial and non-empty peer list. +func TestComponents_AllPeersGetValidMaps(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + for peerID := range account.Peers { + if _, validated := validatedPeers[peerID]; !validated { + continue + } + nm := componentsNetworkMap(account, peerID, validatedPeers) + require.NotNil(t, nm, "network map should not be nil for %s", peerID) + assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID) + assert.NotEmpty(t, nm.Peers, "validated peer %s should see other peers", peerID) + } +} + +// TestComponents_LargeScaleMapGeneration verifies that components can generate maps +// at larger scales without errors and with consistent output. +func TestComponents_LargeScaleMapGeneration(t *testing.T) { + scales := []struct{ peers, groups int }{ + {500, 20}, + {1000, 50}, + } + for _, s := range scales { + t.Run(fmt.Sprintf("%dpeers_%dgroups", s.peers, s.groups), func(t *testing.T) { + account, validatedPeers := scalableTestAccount(s.peers, s.groups) + testPeers := []string{"peer-0", fmt.Sprintf("peer-%d", s.peers/4), fmt.Sprintf("peer-%d", s.peers/2)} + for _, peerID := range testPeers { + nm := componentsNetworkMap(account, peerID, validatedPeers) + require.NotNil(t, nm, "network map should not be nil for %s", peerID) + assert.NotEmpty(t, nm.Peers, "peer %s should see other peers at scale", peerID) + assert.NotEmpty(t, nm.Routes, "peer %s should have routes at scale", peerID) + assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID) + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// 11. PEER-AS-RESOURCE POLICIES +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_PeerAsSourceResource verifies that a policy with SourceResource.Type=Peer +// targets only that specific peer as the source. +func TestComponents_PeerAsSourceResource(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-peer-src", Name: "Peer Source Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-peer-src", Name: "Peer Source Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"443"}, + SourceResource: types.Resource{ID: "peer-0", Type: types.ResourceTypePeer}, + Destinations: []string{"group-1"}, + }}, + }) + + // peer-0 is the source resource, should see group-1 peers + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + + has443 := false + for _, rule := range nm0.FirewallRules { + if rule.Port == "443" { + has443 = true + break + } + } + assert.True(t, has443, "peer-0 as source resource should have port 443 rule") +} + +// TestComponents_PeerAsDestinationResource verifies that a policy with DestinationResource.Type=Peer +// targets only that specific peer as the destination. +func TestComponents_PeerAsDestinationResource(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-peer-dst", Name: "Peer Dest Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-peer-dst", Name: "Peer Dest Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"443"}, + Sources: []string{"group-0"}, + DestinationResource: types.Resource{ID: "peer-15", Type: types.ResourceTypePeer}, + }}, + }) + + // peer-0 is in group-0 (source), should see peer-15 as destination + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + + peerIDs := make(map[string]bool, len(nm0.Peers)) + for _, p := range nm0.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-15"], "peer-0 should see peer-15 via peer-as-destination-resource policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 12. MULTIPLE RULES PER POLICY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_MultipleRulesPerPolicy verifies a policy with multiple rules generates +// firewall rules for each. +func TestComponents_MultipleRulesPerPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-rule", Name: "Multi Rule Policy", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{ + { + ID: "rule-http", Name: "Allow HTTP", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"80"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + { + ID: "rule-https", Name: "Allow HTTPS", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"443"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + }, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has80, has443 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "80" { + has80 = true + } + if rule.Port == "443" { + has443 = true + } + } + assert.True(t, has80, "should have firewall rule for port 80 from first rule") + assert.True(t, has443, "should have firewall rule for port 443 from second rule") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 13. SSH AUTHORIZED USERS CONTENT +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_SSHAuthorizedUsersContent verifies that SSH policies populate +// the AuthorizedUsers map with the correct users and machine mappings. +func TestComponents_SSHAuthorizedUsersContent(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Users["user-dev"] = &types.User{Id: "user-dev", Role: types.UserRoleUser, AccountID: "test-account", AutoGroups: []string{"ssh-users"}} + account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}} + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "Allow SSH", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Bidirectional: false, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + AuthorizedGroups: map[string][]string{"ssh-users": {"root", "admin"}}, + }}, + }) + + // peer-10 is in group-1 (destination) + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be enabled") + assert.NotNil(t, nm.AuthorizedUsers, "AuthorizedUsers should not be nil") + assert.NotEmpty(t, nm.AuthorizedUsers, "AuthorizedUsers should have entries") + + // Check that "root" machine user mapping exists + _, hasRoot := nm.AuthorizedUsers["root"] + _, hasAdmin := nm.AuthorizedUsers["admin"] + assert.True(t, hasRoot || hasAdmin, "AuthorizedUsers should contain 'root' or 'admin' machine user mapping") +} + +// TestComponents_SSHLegacyImpliedSSH verifies that a non-SSH ALL protocol policy with +// SSHEnabled peer implies legacy SSH access. +func TestComponents_SSHLegacyImpliedSSH(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // Enable SSH on the destination peer + account.Peers["peer-10"].SSHEnabled = true + + // The default "Allow All" policy with Protocol=ALL + SSHEnabled peer should imply SSH + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be implied by ALL protocol policy with SSHEnabled peer") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 14. ROUTE DEFAULT PERMIT (no AccessControlGroups) +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_RouteDefaultPermit verifies that a route without AccessControlGroups +// generates default permit firewall rules (0.0.0.0/0 source). +func TestComponents_RouteDefaultPermit(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // Add a route without ACGs — this peer is the routing peer + routingPeerID := "peer-5" + account.Routes["route-no-acg"] = &route.Route{ + ID: "route-no-acg", Network: netip.MustParsePrefix("192.168.99.0/24"), + PeerID: routingPeerID, Peer: account.Peers[routingPeerID].Key, + Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"}, + AccessControlGroups: []string{}, + AccountID: "test-account", + } + + // The routing peer should get default permit route firewall rules + nm := componentsNetworkMap(account, routingPeerID, validatedPeers) + require.NotNil(t, nm) + + hasDefaultPermit := false + for _, rfr := range nm.RoutesFirewallRules { + for _, src := range rfr.SourceRanges { + if src == "0.0.0.0/0" || src == "::/0" { + hasDefaultPermit = true + break + } + } + } + assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 15. MULTIPLE ROUTERS PER NETWORK +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_MultipleRoutersPerNetwork verifies that a network resource +// with multiple routers provides routes through all available routers. +func TestComponents_MultipleRoutersPerNetwork(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + netID := "net-multi-router" + resID := "res-multi-router" + account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Multi Router Network", AccountID: "test-account"}) + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: "multi-svc.netbird.cloud", + }) + account.NetworkRouters = append(account.NetworkRouters, + &routerTypes.NetworkRouter{ID: "router-a", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account", Metric: 100}, + &routerTypes.NetworkRouter{ID: "router-b", NetworkID: netID, Peer: "peer-15", Enabled: true, AccountID: "test-account", Metric: 200}, + ) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-router-res", Name: "Multi Router Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-multi-router-res", Name: "Allow Multi Router", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID}, + }}, + }) + + // peer-0 is in group-0 (source), should see both router peers + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "source peer should see router-a (peer-5)") + assert.True(t, peerIDs["peer-15"], "source peer should see router-b (peer-15)") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 16. PEER-AS-NAMESERVER EXCLUSION +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_PeerIsNameserverExcludedFromNSGroup verifies that a peer serving +// as a nameserver does not receive its own NS group in DNS config. +func TestComponents_PeerIsNameserverExcludedFromNSGroup(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // peer-0 has IP 100.64.0.0 — make it a nameserver + nsIP := account.Peers["peer-0"].IP + account.NameServerGroups["ns-self"] = &nbdns.NameServerGroup{ + ID: "ns-self", Name: "Self NS", Enabled: true, Groups: []string{"group-all"}, + NameServers: []nbdns.NameServer{{IP: netip.AddrFrom4([4]byte{nsIP[0], nsIP[1], nsIP[2], nsIP[3]}), NSType: nbdns.UDPNameServerType, Port: 53}}, + } + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasSelfNS := false + for _, ns := range nm.DNSConfig.NameServerGroups { + if ns.ID == "ns-self" { + hasSelfNS = true + } + } + assert.False(t, hasSelfNS, "peer serving as nameserver should NOT receive its own NS group") + + // peer-10 is NOT the nameserver, should receive the NS group + nm10 := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm10) + hasNSForPeer10 := false + for _, ns := range nm10.DNSConfig.NameServerGroups { + if ns.ID == "ns-self" { + hasNSForPeer10 = true + } + } + assert.True(t, hasNSForPeer10, "non-nameserver peer should receive the NS group") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 17. DOMAIN NETWORK RESOURCES +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_DomainNetworkResource verifies that domain-based network resources +// produce routes with the correct domain configuration. +func TestComponents_DomainNetworkResource(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + netID := "net-domain" + resID := "res-domain" + account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Domain Network", AccountID: "test-account"}) + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: "api.example.com", Type: "domain", + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-domain", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account", + }) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-domain-res", Name: "Domain Resource Policy", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-domain-res", Name: "Allow Domain", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID}, + }}, + }) + + // peer-0 is source, should get route to the domain resource via peer-5 + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "source peer should see domain resource router peer") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 18. DISABLED RULE WITHIN ENABLED POLICY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_DisabledRuleInEnabledPolicy verifies that a disabled rule within +// an enabled policy does not generate firewall rules. +func TestComponents_DisabledRuleInEnabledPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-mixed-rules", Name: "Mixed Rules", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{ + { + ID: "rule-enabled", Name: "Enabled Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"3000"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + { + ID: "rule-disabled", Name: "Disabled Rule", Enabled: false, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"3001"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + }, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has3000, has3001 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "3000" { + has3000 = true + } + if rule.Port == "3001" { + has3001 = true + } + } + assert.True(t, has3000, "enabled rule should generate firewall rule for port 3000") + assert.False(t, has3001, "disabled rule should NOT generate firewall rule for port 3001") +} diff --git a/management/server/types/networkmap_components_test.go b/management/server/types/networkmap_components_test.go new file mode 100644 index 000000000..dde639ccb --- /dev/null +++ b/management/server/types/networkmap_components_test.go @@ -0,0 +1,787 @@ +package types_test + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap { + t.Helper() + return account.GetPeerNetworkMapFromComponents( + context.Background(), + peerID, + account.GetPeersCustomZone(context.Background(), "netbird.io"), + nil, + validatedPeers, + account.GetResourcePoliciesMap(), + account.GetResourceRoutersMap(), + nil, + account.GetActiveGroupUsers(), + ) +} + +func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} { + excludeSet := make(map[string]struct{}, len(excludePeerIDs)) + for _, id := range excludePeerIDs { + excludeSet[id] = struct{}{} + } + validated := make(map[string]struct{}, len(account.Peers)) + for id := range account.Peers { + if _, excluded := excludeSet[id]; !excluded { + validated[id] = struct{}{} + } + } + return validated +} + +func peerIDs(peers []*nbpeer.Peer) []string { + ids := make([]string, len(peers)) + for i, p := range peers { + ids[i] = p.ID + } + return ids +} + +func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.NotNil(t, nm) + assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy") + assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy") + assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself") + assert.Empty(t, nm.OfflinePeers, "no expired peers expected") +} + +func TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) { + account := createComponentTestAccount() + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-intra-src", Name: "Intra-source connectivity", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-intra-src", Name: "src <-> src", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-src"}, + }}, + }) + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Contains(t, peerIDs(nm.Peers), "peer-src-2", "should see peer from same group with intra-group policy") +} + +func TestNetworkMapComponents_FirewallRules(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + require.NotEmpty(t, nm.FirewallRules, "firewall rules should be generated") + + var hasAcceptAll bool + for _, rule := range nm.FirewallRules { + if rule.Protocol == string(types.PolicyRuleProtocolALL) && rule.Action == string(types.PolicyTrafficActionAccept) { + hasAcceptAll = true + } + } + assert.True(t, hasAcceptAll, "should have an accept-all firewall rule from the base policy") +} + +func TestNetworkMapComponents_LoginExpiration(t *testing.T) { + account := createComponentTestAccount() + account.Settings.PeerLoginExpirationEnabled = true + account.Settings.PeerLoginExpiration = 1 * time.Hour + + expiredTime := time.Now().Add(-2 * time.Hour) + account.Peers["peer-dst-1"].LoginExpirationEnabled = true + account.Peers["peer-dst-1"].LastLogin = &expiredTime + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.Contains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "expired peer should be in OfflinePeers") + assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "expired peer should NOT be in active Peers") +} + +func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account, "peer-dst-1") + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "non-validated peer should be excluded") + assert.NotContains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "non-validated peer should not be in offline peers either") +} + +func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account, "peer-src-1") + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.Empty(t, nm.Peers, "non-validated target peer should get empty network map") + assert.Empty(t, nm.FirewallRules) +} + +func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasResourceRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" { + hasResourceRoute = true + break + } + } + assert.True(t, hasResourceRoute, "source peer should receive route to network resource via router") + assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "source peer should see the routing peer") +} + +func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + var hasResourceRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" { + hasResourceRoute = true + break + } + } + assert.True(t, hasResourceRoute, "router peer should receive network resource route") + assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource") +} + +func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "unrelated peer should not receive network resource route") + } +} + +func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.PostureChecks = []*posture.Checks{ + {ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}, + }}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id, + SourcePostureChecks: []string{"pc-version"}, + Rules: []*types.PolicyRule{{ + ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-guarded"}, + }}, + }) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", Enabled: true, AccountID: account.Id, + }) + + t.Run("peer passes posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasGuardedRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.1.1/32" { + hasGuardedRoute = true + } + } + assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route") + }) + + t.Run("peer fails posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route") + } + }) +} + +func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.PostureChecks = []*posture.Checks{ + {ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}, + }}, + {ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}}, + }}, + } + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id, + SourcePostureChecks: []string{"pc-version", "pc-os"}, + Rules: []*types.PolicyRule{{ + ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-strict"}, + }}, + }) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-strict", NetworkID: "net-strict", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-strict", Name: "Strict Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id, + }) + + t.Run("passes both posture checks", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + account.Peers["peer-src-1"].Meta.GoOS = "linux" + account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var found bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.2.1/32" { + found = true + } + } + assert.True(t, found, "peer passing both checks should get resource route") + }) + + t.Run("fails version posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0" + account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route") + } + }) + + t.Run("fails OS posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route") + } + }) +} + +func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + var resourceFWRules []*types.RouteFirewallRule + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "10.200.0.1/32" { + resourceFWRules = append(resourceFWRules, rule) + } + } + assert.NotEmpty(t, resourceFWRules, "router should have firewall rules for the network resource") + + var hasSourcePeerIP bool + for _, rule := range resourceFWRules { + for _, sr := range rule.SourceRanges { + if sr == account.Peers["peer-src-1"].IP.String()+"/32" || sr == account.Peers["peer-src-2"].IP.String()+"/32" { + hasSourcePeerIP = true + } + } + } + assert.True(t, hasSourcePeerIP, "resource firewall rules should include source peer IPs") +} + +func TestNetworkMapComponents_DNSManagement(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + t.Run("peer in DNS-enabled group", func(t *testing.T) { + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.True(t, nm.DNSConfig.ServiceEnable, "peer in non-disabled group should have DNS enabled") + }) + + t.Run("peer in DNS-disabled group", func(t *testing.T) { + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + assert.False(t, nm.DNSConfig.ServiceEnable, "peer in DNS-disabled group should have DNS disabled") + }) +} + +func TestNetworkMapComponents_NameServerGroups(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.True(t, nm.DNSConfig.ServiceEnable) + + var hasNSGroup bool + for _, ns := range nm.DNSConfig.NameServerGroups { + if ns.ID == "ns-main" { + hasNSGroup = true + } + } + assert.True(t, hasNSGroup, "peer in NS group should receive nameserver configuration") +} + +func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-ha-1"] = &route.Route{ + ID: "route-ha-1", Network: netip.MustParsePrefix("172.16.0.0/16"), + Peer: account.Peers["peer-dst-1"].Key, PeerID: "peer-dst-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"}, + } + account.Routes["route-ha-2"] = &route.Route{ + ID: "route-ha-2", Network: netip.MustParsePrefix("172.16.0.0/16"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 200, AccountID: account.Id, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-src"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + haCount := 0 + for _, r := range nm.Routes { + if r.Network.String() == "172.16.0.0/16" { + haCount++ + } + } + assert.Equal(t, 1, haCount, "peer should only receive one route from HA group (not both, since it's a member of one)") +} + +func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-acl"] = &route.Route{ + ID: "route-acl", Network: netip.MustParsePrefix("192.168.100.0/24"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-dst"}, + PeerGroups: []string{"group-src"}, + AccessControlGroups: []string{"group-dst"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasFWRule bool + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "192.168.100.0/24" { + hasFWRule = true + } + } + assert.True(t, hasFWRule, "routing peer should have firewall rules for route with access control groups") +} + +func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-open"] = &route.Route{ + ID: "route-open", Network: netip.MustParsePrefix("10.99.0.0/16"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-src"}, + PeerGroups: []string{"group-src"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasFWRule bool + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "10.99.0.0/16" { + hasFWRule = true + } + } + assert.True(t, hasFWRule, "route without access control groups should have default permit firewall rules") +} + +func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Peers["peer-dst-1"].SSHEnabled = true + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "SSH to dst", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH") +} + +func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, p := range account.Policies { + p.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers") + assert.Empty(t, nm.FirewallRules) +} + +func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, r := range account.Routes { + r.Enabled = false + } + for _, r := range account.NetworkResources { + r.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Empty(t, nm.Routes, "disabled routes should not appear in network map") +} + +func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, r := range account.NetworkResources { + r.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "disabled resource should not generate routes") + } +} + +func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nmSrc := networkMapFromComponents(t, account, "peer-src-1", validated) + nmDst := networkMapFromComponents(t, account, "peer-dst-1", validated) + + assert.Contains(t, peerIDs(nmSrc.Peers), "peer-dst-1", "src should see dst via bidirectional policy") + assert.Contains(t, peerIDs(nmDst.Peers), "peer-src-1", "dst should see src via bidirectional policy") +} + +func TestNetworkMapComponents_DropPolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-drop", Name: "Drop traffic", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-drop", Name: "Drop src->dst", Enabled: true, + Action: types.PolicyTrafficActionDrop, Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"5432"}, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasDropRule bool + for _, rule := range nm.FirewallRules { + if rule.Action == string(types.PolicyTrafficActionDrop) && rule.Port == "5432" { + hasDropRule = true + } + } + assert.True(t, hasDropRule, "drop policy should generate drop firewall rule") +} + +func TestNetworkMapComponents_PortRangePolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Peers["peer-src-1"].Meta.WtVersion = "0.50.0" + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-range", Name: "Port range", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-range", Name: "Range rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasRangeRule bool + for _, rule := range nm.FirewallRules { + if rule.PortRange.Start == 8080 && rule.PortRange.End == 8090 { + hasRangeRule = true + } + } + assert.True(t, hasRangeRule, "port range policy should generate corresponding firewall rule") +} + +func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-2", NetworkID: "net-1", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.2/32"), Address: "10.200.0.2/32", + }) + account.Groups["group-res2"] = &types.Group{ID: "group-res2", Name: "Resource 2 Group", Peers: []string{"peer-src-1", "peer-src-2"}, + Resources: []types.Resource{{ID: "resource-2"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-res2", Name: "Resource 2 Policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-res2", Name: "Access Resource 2", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-2"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + resourceRouteCount := 0 + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" || r.Network.String() == "10.200.0.2/32" { + resourceRouteCount++ + } + } + assert.Equal(t, 2, resourceRouteCount, "router should have routes for both network resources") +} + +func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-domain", NetworkID: "net-1", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Domain, Domain: "api.example.com", Address: "api.example.com", + }) + account.Groups["group-res-domain"] = &types.Group{ + ID: "group-res-domain", Name: "Domain Resource Group", + Resources: []types.Resource{{ID: "resource-domain"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-domain", Name: "Domain resource policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-domain", Name: "Access domain resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-domain"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasDomainRoute bool + for _, r := range nm.Routes { + if r.NetworkType == route.DomainNetwork && len(r.Domains) > 0 && r.Domains[0].SafeString() == "api.example.com" { + hasDomainRoute = true + } + } + assert.True(t, hasDomainRoute, "source peer should receive domain route for domain network resource") +} + +func TestNetworkMapComponents_NetworkEmpty(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "nonexistent-peer", validated) + + assert.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) + assert.NotNil(t, nm.Network) +} + +func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-other", NetworkID: "net-other", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.99.1/32"), Address: "10.200.99.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-other", Name: "Other Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-other", NetworkID: "net-other", Peer: "peer-dst-1", Enabled: true, AccountID: account.Id, + }) + account.Groups["group-res-other"] = &types.Group{ID: "group-res-other", Name: "Other resource group", + Resources: []types.Resource{{ID: "resource-other"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-other-resource", Name: "Other resource policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-other", Name: "Other resource access", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-other"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.99.1/32", r.Network.String(), "router-1 should NOT get routes for other network's resources") + } +} + +func createComponentTestAccount() *types.Account { + peers := map[string]*nbpeer.Peer{ + "peer-src-1": { + ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-src-2": { + ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-dst-1": { + ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-router-1": { + ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + } + + groups := map[string]*types.Group{ + "group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}}, + "group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}}, + "group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}}, + "group-res": { + ID: "group-res", Name: "Resource Group", + Resources: []types.Resource{{ID: "resource-1"}}, + }, + } + + policies := []*types.Policy{ + { + ID: "policy-base", Name: "Base connectivity", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-base", Name: "Allow src <-> dst", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }, + { + ID: "policy-resource", Name: "Network resource access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-resource", Name: "Source -> Resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-1"}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + "route-main": { + ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1", + Enabled: true, Metric: 100, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"}, + }, + } + + users := map[string]*types.User{ + "user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}}, + "user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}}, + } + + account := &types.Account{ + Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Users: users, + Network: &types.Network{ + Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "ns-main": { + ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{}, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32", + }, + }, + Networks: []*networkTypes.Network{ + {ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"}, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go deleted file mode 100644 index 53261f22d..000000000 --- a/management/server/types/networkmap_golden_test.go +++ /dev/null @@ -1,967 +0,0 @@ -package types_test - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/netip" - "os" - "path/filepath" - "slices" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/route" -) - -const ( - numPeers = 100 - devGroupID = "group-dev" - opsGroupID = "group-ops" - allGroupID = "group-all" - sshUsersGroupID = "group-ssh-users" - routeID = route.ID("route-main") - routeHA1ID = route.ID("route-ha-1") - routeHA2ID = route.ID("route-ha-2") - policyIDDevOps = "policy-dev-ops" - policyIDAll = "policy-all" - policyIDPosture = "policy-posture" - policyIDDrop = "policy-drop" - policyIDSSH = "policy-ssh" - postureCheckID = "posture-check-ver" - networkResourceID = "res-database" - networkID = "net-database" - networkRouterID = "router-database" - nameserverGroupID = "ns-group-main" - testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. - expiredPeerID = "peer-98" // This peer will be online but with an expired session. - offlinePeerID = "peer-99" // This peer will be completely offline. - routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. - testAccountID = "account-golden-test" - userAdminID = "user-admin" - userDevID = "user-dev" - userOpsID = "user-ops" -) - -func TestGetPeerNetworkMap_Golden(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - - b.ResetTimer() - b.Run("old builder", func(b *testing.B) { - for range b.N { - for _, peerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - b.ResetTimer() - b.Run("new builder", func(b *testing.B) { - for range b.N { - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - for _, peerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newPeerID := "peer-new-101" - newPeerIP := net.IP{100, 64, 1, 1} - newPeer := &nbpeer.Peer{ - ID: newPeerID, - IP: newPeerIP, - Key: fmt.Sprintf("key-%s", newPeerID), - DNSLabel: "peernew101", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newPeerID] = newPeer - - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = append(devGroup.Peers, newPeerID) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newPeerID) - } - - validatedPeersMap[newPeerID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerAddedIncremental(account, newPeerID) - require.NoError(t, err, "error adding peer to cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newPeerID := "peer-new-101" - newPeer := &nbpeer.Peer{ - ID: newPeerID, - IP: net.IP{100, 64, 1, 1}, - Key: fmt.Sprintf("key-%s", newPeerID), - DNSLabel: "peernew101", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - } - - account.Peers[newPeerID] = newPeer - account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) - account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) - validatedPeersMap[newPeerID] = struct{}{} - - b.ResetTimer() - b.Run("old builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(account, newPeerID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerAddedIncremental(account, newRouterID) - require.NoError(t, err, "error adding router to cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - b.ResetTimer() - b.Run("old builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(account, newRouterID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedPeerID := "peer-25" - - delete(account.Peers, deletedPeerID) - - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - delete(validatedPeersMap, deletedPeerID) - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerDeleted(account, deletedPeerID) - require.NoError(t, err, "error deleting peer from cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match") - } -} - -func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedRouterID := "peer-75" - - var affectedRoute *route.Route - for _, r := range account.Routes { - if r.PeerID == deletedRouterID { - affectedRoute = r - break - } - } - require.NotNil(t, affectedRoute, "Router peer should have a route") - - for _, group := range account.Groups { - group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { - return id == deletedRouterID - }) - } - - for routeID, r := range account.Routes { - if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { - delete(account.Routes, routeID) - } - } - delete(account.Peers, deletedRouterID) - delete(validatedPeersMap, deletedRouterID) - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerDeleted(account, deletedRouterID) - require.NoError(t, err, "error deleting routing peer from cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - - deletedPeerID := "peer-25" - - delete(account.Peers, deletedPeerID) - account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { - return id == deletedPeerID - }) - account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { - return id == deletedPeerID - }) - delete(validatedPeersMap, deletedPeerID) - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - b.ResetTimer() - b.Run("old builder after delete", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after delete", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerDeleted(account, deletedPeerID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { - for _, peer := range networkMap.Peers { - if peer.Status != nil { - peer.Status.LastSeen = time.Time{} - } - peer.LastLogin = &time.Time{} - } - for _, peer := range networkMap.OfflinePeers { - if peer.Status != nil { - peer.Status.LastSeen = time.Time{} - } - peer.LastLogin = &time.Time{} - } - - sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) - sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) - sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) - - sort.Slice(networkMap.FirewallRules, func(i, j int) bool { - r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] - if r1.PeerIP != r2.PeerIP { - return r1.PeerIP < r2.PeerIP - } - if r1.Protocol != r2.Protocol { - return r1.Protocol < r2.Protocol - } - if r1.Direction != r2.Direction { - return r1.Direction < r2.Direction - } - if r1.Action != r2.Action { - return r1.Action < r2.Action - } - return r1.Port < r2.Port - }) - - sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { - r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] - if r1.RouteID != r2.RouteID { - return r1.RouteID < r2.RouteID - } - if r1.Action != r2.Action { - return r1.Action < r2.Action - } - if r1.Destination != r2.Destination { - return r1.Destination < r2.Destination - } - if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 { - if r1.SourceRanges[0] != r2.SourceRanges[0] { - return r1.SourceRanges[0] < r2.SourceRanges[0] - } - } - return r1.Port < r2.Port - }) - - for _, ranges := range networkMap.RoutesFirewallRules { - sort.Slice(ranges.SourceRanges, func(i, j int) bool { - return ranges.SourceRanges[i] < ranges.SourceRanges[j] - }) - } -} - -type networkMapJSON struct { - Peers []*nbpeer.Peer `json:"Peers"` - Network *types.Network `json:"Network"` - Routes []*route.Route `json:"Routes"` - DNSConfig dns.Config `json:"DNSConfig"` - OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"` - FirewallRules []*types.FirewallRule `json:"FirewallRules"` - RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"` - ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"` - AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"` - EnableSSH bool `json:"EnableSSH"` -} - -func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON { - result := &networkMapJSON{ - Peers: nm.Peers, - Network: nm.Network, - Routes: nm.Routes, - DNSConfig: nm.DNSConfig, - OfflinePeers: nm.OfflinePeers, - FirewallRules: nm.FirewallRules, - RoutesFirewallRules: nm.RoutesFirewallRules, - ForwardingRules: nm.ForwardingRules, - EnableSSH: nm.EnableSSH, - } - - if len(nm.AuthorizedUsers) > 0 { - result.AuthorizedUsers = make(map[string][]string) - localUsers := make([]string, 0, len(nm.AuthorizedUsers)) - for localUser := range nm.AuthorizedUsers { - localUsers = append(localUsers, localUser) - } - sort.Strings(localUsers) - - for _, localUser := range localUsers { - userIDs := nm.AuthorizedUsers[localUser] - sortedUserIDs := make([]string, 0, len(userIDs)) - for userID := range userIDs { - sortedUserIDs = append(sortedUserIDs, userID) - } - sort.Strings(sortedUserIDs) - result.AuthorizedUsers[localUser] = sortedUserIDs - } - } - - return result -} - -func createTestAccountWithEntities() *types.Account { - peers := make(map[string]*nbpeer.Peer) - devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} - - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, 64, 0, byte(i + 1)} - wtVersion := "0.25.0" - if i%2 == 0 { - wtVersion = "0.40.0" - } - - p := &nbpeer.Peer{ - ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, - } - - if peerID == expiredPeerID { - p.LoginExpirationEnabled = true - pastTimestamp := time.Now().Add(-2 * time.Hour) - p.LastLogin = &pastTimestamp - } - - peers[peerID] = p - allGroupPeers = append(allGroupPeers, peerID) - if i < numPeers/2 { - devGroupPeers = append(devGroupPeers, peerID) - } else { - opsGroupPeers = append(opsGroupPeers, peerID) - } - - } - - groups := map[string]*types.Group{ - allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, - devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, - opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, - sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}}, - } - - policies := []*types.Policy{ - { - ID: policyIDAll, Name: "Default-Allow", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{allGroupID}, Destinations: []string{allGroupID}, - }}, - }, - { - ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, - PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, - Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, - SourcePostureChecks: []string{postureCheckID}, - Rules: []*types.PolicyRule{{ - ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, - }}, - }, - { - ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}}, - }}, - }, - } - - routes := map[route.ID]*route.Route{ - routeID: { - ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), - Peer: peers["peer-75"].Key, - PeerID: "peer-75", - Description: "Route to internal resource", Enabled: true, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - }, - routeHA1ID: { - ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-80"].Key, - PeerID: "peer-80", - Description: "HA Route 1", Enabled: true, Metric: 1000, - PeerGroups: []string{allGroupID}, - Groups: []string{allGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - routeHA2ID: { - ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-90"].Key, - PeerID: "peer-90", - Description: "HA Route 2", Enabled: true, Metric: 900, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - } - - users := map[string]*types.User{ - userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}}, - userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}}, - userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}}, - } - - account := &types.Account{ - Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, - Users: users, - Network: &types.Network{ - Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, - }, - DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, - NameServerGroups: map[string]*dns.NameServerGroup{ - nameserverGroupID: { - ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, - NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, - }, - }, - PostureChecks: []*posture.Checks{ - {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, - }}, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, - }, - Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, - NetworkRouters: []*routerTypes.NetworkRouter{ - {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, - }, - Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, - } - - for _, p := range account.Policies { - p.AccountID = account.Id - } - for _, r := range account.Routes { - r.AccountID = account.Id - } - - return account -} - -func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - builder.EnqueuePeersForIncrementalAdd(account, newRouterID) - - time.Sleep(100 * time.Millisecond) - - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") - - t.Log("Update golden file with OnPeerAdded router...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") -} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go deleted file mode 100644 index 6448b8403..000000000 --- a/management/server/types/networkmapbuilder.go +++ /dev/null @@ -1,2317 +0,0 @@ -package types - -import ( - "context" - "fmt" - "slices" - "strconv" - "strings" - "sync" - "time" - - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - - "github.com/netbirdio/netbird/client/ssh/auth" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/route" -) - -const ( - allPeers = "0.0.0.0" - allWildcard = "0.0.0.0/0" - v6AllWildcard = "::/0" - fw = "fw:" - rfw = "route-fw:" - - szAddPeerBatch = 10 - maxPeerAddRetries = 20 -) - -type NetworkMapCache struct { - globalRoutes map[route.ID]*route.Route - globalRules map[string]*FirewallRule //ruleId - globalRouteRules map[string]*RouteFirewallRule //ruleId - globalPeers map[string]*nbpeer.Peer - - groupToPeers map[string][]string - peerToGroups map[string][]string - policyToRules map[string][]*PolicyRule //policyId - groupToPolicies map[string][]*Policy - groupToRoutes map[string][]*route.Route - peerToRoutes map[string][]*route.Route - - peerACLs map[string]*PeerACLView - peerRoutes map[string]*PeerRoutesView - peerDNS map[string]*nbdns.Config - peerSSH map[string]*PeerSSHView - - groupIDToUserIDs map[string][]string - allowedUserIDs map[string]struct{} - - resourceRouters map[string]map[string]*routerTypes.NetworkRouter - resourcePolicies map[string][]*Policy - - globalResources map[string]*resourceTypes.NetworkResource // resourceId - - acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info - noACGRoutes map[route.ID]*RouteOwnerInfo - - mu sync.RWMutex -} - -type RouteOwnerInfo struct { - PeerID string - RouteID route.ID -} - -type PeerACLView struct { - ConnectedPeerIDs []string - FirewallRuleIDs []string -} - -type PeerRoutesView struct { - OwnRouteIDs []route.ID - NetworkResourceIDs []route.ID - InheritedRouteIDs []route.ID - RouteFirewallRuleIDs []string -} - -type PeerSSHView struct { - EnableSSH bool - AuthorizedUsers map[string]map[string]struct{} -} - -type NetworkMapBuilder struct { - account *Account - cache *NetworkMapCache - validatedPeers map[string]struct{} - - apb addPeerBatch -} - -type addPeerBatch struct { - mu sync.Mutex - sg *sync.Cond - ids []string - la *Account - retryCount map[string]int -} - -func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { - builder := &NetworkMapBuilder{ - cache: &NetworkMapCache{ - globalRoutes: make(map[route.ID]*route.Route), - globalRules: make(map[string]*FirewallRule), - globalRouteRules: make(map[string]*RouteFirewallRule), - globalPeers: make(map[string]*nbpeer.Peer), - groupToPeers: make(map[string][]string), - peerToGroups: make(map[string][]string), - policyToRules: make(map[string][]*PolicyRule), - groupToPolicies: make(map[string][]*Policy), - groupToRoutes: make(map[string][]*route.Route), - peerToRoutes: make(map[string][]*route.Route), - peerACLs: make(map[string]*PeerACLView), - peerRoutes: make(map[string]*PeerRoutesView), - peerDNS: make(map[string]*nbdns.Config), - peerSSH: make(map[string]*PeerSSHView), - groupIDToUserIDs: make(map[string][]string), - allowedUserIDs: make(map[string]struct{}), - globalResources: make(map[string]*resourceTypes.NetworkResource), - acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), - noACGRoutes: make(map[route.ID]*RouteOwnerInfo), - }, - validatedPeers: make(map[string]struct{}), - } - builder.apb.sg = sync.NewCond(&builder.apb.mu) - builder.apb.ids = make([]string, 0, szAddPeerBatch) - builder.apb.la = account - builder.apb.retryCount = make(map[string]int) - - maps.Copy(builder.validatedPeers, validatedPeers) - - builder.initialBuild(account) - - go builder.incAddPeerLoop() - return builder -} - -func (b *NetworkMapBuilder) initialBuild(account *Account) { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - b.account = account - - start := time.Now() - - b.buildGlobalIndexes(account) - - resourceRouters := account.GetResourceRoutersMap() - resourcePolicies := account.GetResourcePoliciesMap() - b.cache.resourceRouters = resourceRouters - b.cache.resourcePolicies = resourcePolicies - - for peerID := range account.Peers { - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - } - - log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) -} - -func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { - clear(b.cache.globalPeers) - clear(b.cache.groupToPeers) - clear(b.cache.peerToGroups) - clear(b.cache.policyToRules) - clear(b.cache.groupToPolicies) - clear(b.cache.globalRoutes) - clear(b.cache.globalRules) - clear(b.cache.globalRouteRules) - clear(b.cache.globalResources) - clear(b.cache.groupToRoutes) - clear(b.cache.peerToRoutes) - clear(b.cache.acgToRoutes) - clear(b.cache.noACGRoutes) - clear(b.cache.groupIDToUserIDs) - clear(b.cache.allowedUserIDs) - clear(b.cache.peerSSH) - - maps.Copy(b.cache.globalPeers, account.Peers) - - b.cache.groupIDToUserIDs = account.GetActiveGroupUsers() - b.cache.allowedUserIDs = b.buildAllowedUserIDs(account) - - for groupID, group := range account.Groups { - peersCopy := make([]string, len(group.Peers)) - copy(peersCopy, group.Peers) - b.cache.groupToPeers[groupID] = peersCopy - - for _, peerID := range group.Peers { - b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) - } - } - - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - b.cache.policyToRules[policy.ID] = policy.Rules - - affectedGroups := make(map[string]struct{}) - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, groupID := range rule.Sources { - affectedGroups[groupID] = struct{}{} - } - for _, groupID := range rule.Destinations { - affectedGroups[groupID] = struct{}{} - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - groupId := rule.SourceResource.ID - affectedGroups[groupId] = struct{}{} - b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - groupId := rule.DestinationResource.ID - affectedGroups[groupId] = struct{}{} - b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) - } - } - - for groupID := range affectedGroups { - b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) - } - } - - for _, resource := range account.NetworkResources { - if !resource.Enabled { - continue - } - b.cache.globalResources[resource.ID] = resource - } - - for _, r := range account.Routes { - if !r.Enabled { - continue - } - for _, groupID := range r.PeerGroups { - b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) - } - if r.Peer != "" { - if peer, ok := b.cache.globalPeers[r.Peer]; ok { - b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) - } - } - } -} - -func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { - peer := account.GetPeer(peerID) - if peer == nil { - return - } - - allPotentialPeers, firewallRules, authorizedUsers, sshEnabled := b.getPeerConnectionResources(account, peer, b.validatedPeers) - - isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) - - var emptyExpiredPeers []*nbpeer.Peer - finalAllPeers := b.addNetworksRoutingPeers( - networkResourcesRoutes, - peer, - allPotentialPeers, - emptyExpiredPeers, - isRouter, - sourcePeers, - ) - - view := &PeerACLView{ - ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), - FirewallRuleIDs: make([]string, 0, len(firewallRules)), - } - - for _, p := range finalAllPeers { - view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) - } - - for _, rule := range firewallRules { - ruleID := b.generateFirewallRuleID(rule) - view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) - b.cache.globalRules[ruleID] = rule - } - - b.cache.peerACLs[peerID] = view - b.cache.peerSSH[peerID] = &PeerSSHView{ - EnableSSH: sshEnabled, - AuthorizedUsers: authorizedUsers, - } -} - -func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, - validatedPeersMap map[string]struct{}, -) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { - peerID := peer.ID - ctx := context.Background() - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - rulesExists := make(map[string]struct{}) - peersExists := make(map[string]struct{}) - fwRules := make([]*FirewallRule, 0) - peers := make([]*nbpeer.Peer, 0) - - authorizedUsers := make(map[string]map[string]struct{}) - sshEnabled := false - - for _, group := range peerGroups { - policies := b.cache.groupToPolicies[group] - for _, policy := range policies { - if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { - continue - } - rules := b.cache.policyToRules[policy.ID] - for _, rule := range rules { - var sourcePeers, destinationPeers []*nbpeer.Peer - var peerInSources, peerInDestinations bool - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - peerInSources = rule.SourceResource.ID == peerID - } else { - peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - peerInDestinations = rule.DestinationResource.ID == peerID - } else { - peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) - } - - if !peerInSources && !peerInDestinations { - continue - } - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - peer := account.GetPeer(rule.SourceResource.ID) - if peer != nil { - sourcePeers = []*nbpeer.Peer{peer} - } - } else { - sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - peer := account.GetPeer(rule.DestinationResource.ID) - if peer != nil { - destinationPeers = []*nbpeer.Peer{peer} - } - } else { - destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) - } - - if rule.Bidirectional { - if peerInSources { - b.generateResourcescached( - rule, destinationPeers, FirewallRuleDirectionIN, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - if peerInDestinations { - b.generateResourcescached( - rule, sourcePeers, FirewallRuleDirectionOUT, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - } - - if peerInSources { - b.generateResourcescached( - rule, destinationPeers, FirewallRuleDirectionOUT, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - - if peerInDestinations { - b.generateResourcescached( - rule, sourcePeers, FirewallRuleDirectionIN, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - - if rule.Protocol == PolicyRuleProtocolNetbirdSSH { - sshEnabled = true - switch { - case len(rule.AuthorizedGroups) > 0: - for groupID, localUsers := range rule.AuthorizedGroups { - userIDs, ok := b.cache.groupIDToUserIDs[groupID] - if !ok { - continue - } - - if len(localUsers) == 0 { - localUsers = []string{auth.Wildcard} - } - - for _, localUser := range localUsers { - if authorizedUsers[localUser] == nil { - authorizedUsers[localUser] = make(map[string]struct{}) - } - for _, userID := range userIDs { - authorizedUsers[localUser][userID] = struct{}{} - } - } - } - case rule.AuthorizedUser != "": - if authorizedUsers[auth.Wildcard] == nil { - authorizedUsers[auth.Wildcard] = make(map[string]struct{}) - } - authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} - default: - authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) - } - } else if policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { - sshEnabled = true - authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) - } - } - } - } - } - - return peers, fwRules, authorizedUsers, sshEnabled -} - -func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { - for _, groupID := range groupIDs { - if _, exists := peerGroupsMap[groupID]; exists { - return true - } - } - return false -} - -func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, - excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, -) []*nbpeer.Peer { - ctx := context.Background() - uniquePeers := make(map[string]*nbpeer.Peer) - - for _, groupID := range groupIDs { - peerIDs := b.cache.groupToPeers[groupID] - for _, peerID := range peerIDs { - if peerID == excludePeerID { - continue - } - - if _, ok := validatedPeersMap[peerID]; !ok { - continue - } - - peer := b.cache.globalPeers[peerID] - if peer == nil { - continue - } - - if len(postureChecksIDs) > 0 { - if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { - continue - } - } - - uniquePeers[peerID] = peer - } - } - - result := make([]*nbpeer.Peer, 0, len(uniquePeers)) - for _, peer := range uniquePeers { - result = append(result, peer) - } - - return result -} - -func (b *NetworkMapBuilder) generateResourcescached( - rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, - peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, -) { - for _, peer := range groupPeers { - if peer == nil { - continue - } - if _, ok := peersExists[peer.ID]; !ok { - *peers = append(*peers, peer) - peersExists[peer.ID] = struct{}{} - } - - fr := FirewallRule{ - PolicyID: rule.ID, - PeerIP: peer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - - var s strings.Builder - s.WriteString(rule.ID) - s.WriteString(fr.PeerIP) - s.WriteString(strconv.Itoa(direction)) - s.WriteString(fr.Protocol) - s.WriteString(fr.Action) - s.WriteString(strings.Join(rule.Ports, ",")) - - ruleID := s.String() - - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { - *rules = append(*rules, &fr) - continue - } - - *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...) - } -} - -func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) { - ctx := context.Background() - peerID := peer.ID - - var isRoutingPeer bool - var routes []*route.Route - allSourcePeers := make(map[string]struct{}) - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - for _, resource := range b.cache.globalResources { - - networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID] - resourcePolicies := b.cache.resourcePolicies[resource.ID] - if len(resourcePolicies) == 0 { - continue - } - - isRouterForThisResource := false - - if networkRoutingPeers != nil { - if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled { - isRoutingPeer = true - isRouterForThisResource = true - if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil { - routes = append(routes, rt) - } - } - } - - hasAccessAsClient := false - if !isRouterForThisResource { - for _, policy := range resourcePolicies { - if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) { - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - hasAccessAsClient = true - break - } - } - } - } - - if hasAccessAsClient && networkRoutingPeers != nil { - for routerPeerID, router := range networkRoutingPeers { - if router.Enabled { - if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil { - routes = append(routes, rt) - } - } - } - } - - if isRouterForThisResource { - for _, policy := range resourcePolicies { - var peersWithAccess []*nbpeer.Peer - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peersWithAccess = []*nbpeer.Peer{peer} - } else { - peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers) - } - for _, p := range peersWithAccess { - allSourcePeers[p.ID] = struct{}{} - } - } - } - } - - return isRoutingPeer, routes, allSourcePeers -} - -func (b *NetworkMapBuilder) createNetworkResourceRoutes( - resource *resourceTypes.NetworkResource, routerPeerID string, - router *routerTypes.NetworkRouter, resourcePolicies []*Policy, -) *route.Route { - if len(resourcePolicies) > 0 { - peer := b.cache.globalPeers[routerPeerID] - if peer != nil { - return resource.ToRoute(peer, router) - } - } - return nil -} - -func (b *NetworkMapBuilder) addNetworksRoutingPeers( - networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, - expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{}, -) []*nbpeer.Peer { - - networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) - for _, r := range networkResourcesRoutes { - networkRoutesPeers[r.PeerID] = struct{}{} - } - - delete(sourcePeers, peer.ID) - delete(networkRoutesPeers, peer.ID) - - for _, existingPeer := range peersToConnect { - delete(sourcePeers, existingPeer.ID) - delete(networkRoutesPeers, existingPeer.ID) - } - for _, expPeer := range expiredPeers { - delete(sourcePeers, expPeer.ID) - delete(networkRoutesPeers, expPeer.ID) - } - - missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) - if isRouter { - for p := range sourcePeers { - missingPeers[p] = struct{}{} - } - } - for p := range networkRoutesPeers { - missingPeers[p] = struct{}{} - } - - for p := range missingPeers { - if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { - peersToConnect = append(peersToConnect, missingPeer) - } - } - - return peersToConnect -} - -func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) { - ctx := context.Background() - peer := account.GetPeer(peerID) - if peer == nil { - return - } - resourcePolicies := b.cache.resourcePolicies - - view := &PeerRoutesView{ - OwnRouteIDs: make([]route.ID, 0), - NetworkResourceIDs: make([]route.ID, 0), - RouteFirewallRuleIDs: make([]string, 0), - } - - enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) - for _, rt := range enabledRoutes { - if rt.PeerID != "" && rt.PeerID != peerID { - if b.cache.globalPeers[rt.PeerID] == nil { - continue - } - } - - view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) - b.cache.globalRoutes[rt.ID] = rt - } - - aclView := b.cache.peerACLs[peerID] - if aclView != nil { - peerRoutesMembership := make(LookupMap) - for _, r := range append(enabledRoutes, disabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(LookupMap) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - for _, aclPeerID := range aclView.ConnectedPeerIDs { - if aclPeerID == peerID { - continue - } - activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) - groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) - haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - - for _, inheritedRoute := range haFilteredRoutes { - view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) - b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute - } - } - } - - _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) - - for _, rt := range networkResourcesRoutes { - view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) - b.cache.globalRoutes[rt.ID] = rt - } - - allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) - b.updateACGIndexForPeer(peerID, allRoutes) - - routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) - for _, rule := range routeFirewallRules { - ruleID := b.generateRouteFirewallRuleID(rule) - view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) - b.cache.globalRouteRules[ruleID] = rule - } - - if len(networkResourcesRoutes) > 0 { - networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) - for _, rule := range networkResourceFirewallRules { - ruleID := b.generateRouteFirewallRuleID(rule) - view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) - b.cache.globalRouteRules[ruleID] = rule - } - } - - b.cache.peerRoutes[peerID] = view -} - -func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { - for acg, routeMap := range b.cache.acgToRoutes { - for routeID, info := range routeMap { - if info.PeerID == peerID { - delete(routeMap, routeID) - } - } - if len(routeMap) == 0 { - delete(b.cache.acgToRoutes, acg) - } - } - - for routeID, info := range b.cache.noACGRoutes { - if info.PeerID == peerID { - delete(b.cache.noACGRoutes, routeID) - } - } - - for _, rt := range routes { - if !rt.Enabled { - continue - } - - if len(rt.AccessControlGroups) == 0 { - b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ - PeerID: peerID, - RouteID: rt.ID, - } - } else { - for _, acg := range rt.AccessControlGroups { - if b.cache.acgToRoutes[acg] == nil { - b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) - } - - b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ - PeerID: peerID, - RouteID: rt.ID, - } - } - } - } -} - -func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - peer := b.cache.globalPeers[peerID] - if peer == nil { - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - // maybe here is some mess - here we store peer key (see comment below) - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - peerGroups := b.cache.peerToGroups[peerID] - for _, groupID := range peerGroups { - groupRoutes := b.cache.groupToRoutes[groupID] - for _, r := range groupRoutes { - newPeerRoute := r.Copy() - // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes - newPeerRoute.Peer = peerID - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) - takeRoute(newPeerRoute, peerID) - } - } - for _, r := range b.cache.peerToRoutes[peerID] { - takeRoute(r.Copy(), peerID) - } - return enabledRoutes, disabledRoutes -} - -func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0) - - enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) - for _, route := range enabledRoutes { - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := b.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) - - rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - - return routesFirewallRules -} - -func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - groupPeers := b.cache.groupToPeers[id] - if groupPeers == nil { - continue - } - - for _, pID := range groupPeers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { - routePolicies := make(map[string]*Policy) - - for _, groupID := range accessControlGroups { - candidatePolicies := b.cache.groupToPolicies[groupID] - - for _, policy := range candidatePolicies { - if _, found := routePolicies[policy.ID]; found { - continue - } - policyRules := b.cache.policyToRules[policy.ID] - for _, rule := range policyRules { - if slices.Contains(rule.Destinations, groupID) { - routePolicies[policy.ID] = policy - break - } - } - } - } - - return maps.Values(routePolicies) -} - -func (b *NetworkMapBuilder) getRouteFirewallRules( - peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, - distributionPeers map[string]struct{}, account *Account, -) []*RouteFirewallRule { - ctx := context.Background() - var fwRules []*RouteFirewallRule - for _, policy := range policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) - - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) - fwRules = append(fwRules, rules...) - } - } - return fwRules -} - -func (b *NetworkMapBuilder) getRulePeers( - rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, - validatedPeersMap map[string]struct{}, account *Account, -) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) - - for _, id := range rule.Sources { - groupPeers := b.cache.groupToPeers[id] - if groupPeers == nil { - continue - } - - for _, pID := range groupPeers { - if pID == peerID { - continue - } - _, distPeer := distributionPeers[pID] - _, valid := validatedPeersMap[pID] - - if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { - distPeersWithPolicy[pID] = struct{}{} - } - } - } - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - _, distPeer := distributionPeers[rule.SourceResource.ID] - _, valid := validatedPeersMap[rule.SourceResource.ID] - if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { - distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} - } - } - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { - peer := b.cache.globalPeers[pID] - if peer == nil { - continue - } - distributionGroupPeers = append(distributionGroupPeers, peer) - } - return distributionGroupPeers -} - -func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { - peerGroups := b.cache.peerToGroups[peerID] - checkGroups := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - checkGroups[groupID] = struct{}{} - } - - dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) - dnsConfig := &nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) - } - - b.cache.peerDNS[peerID] = dnsConfig -} - -func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { - - enabled := true - for _, groupID := range account.DNSSettings.DisabledManagementGroups { - _, found := checkGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - -func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := checkGroups[gID] - if found { - peer := b.cache.globalPeers[peerID] - if !peerIsNameserver(peer, nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -func (b *NetworkMapBuilder) buildAllowedUserIDs(account *Account) map[string]struct{} { - users := make(map[string]struct{}) - for _, nbUser := range account.Users { - if !nbUser.IsBlocked() && !nbUser.IsServiceUser { - users[nbUser.Id] = struct{}{} - } - } - return users -} - -func firewallRuleProtocol(protocol PolicyRuleProtocolType) string { - if protocol == PolicyRuleProtocolNetbirdSSH { - return string(PolicyRuleProtocolTCP) - } - return string(protocol) -} - -// lock should be held -func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account { - if account.Network.CurrentSerial() > b.account.Network.CurrentSerial() { - b.account = account - } - return b.account -} - -func (b *NetworkMapBuilder) GetPeerNetworkMap( - ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, - validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - - b.cache.mu.RLock() - defer b.cache.mu.RUnlock() - - account := b.account - - peer := account.GetPeer(peerID) - if peer == nil { - return &NetworkMap{Network: account.Network.Copy()} - } - - aclView := b.cache.peerACLs[peerID] - routesView := b.cache.peerRoutes[peerID] - dnsConfig := b.cache.peerDNS[peerID] - sshView := b.cache.peerSSH[peerID] - - if aclView == nil || routesView == nil || dnsConfig == nil { - return &NetworkMap{Network: account.Network.Copy()} - } - - nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers) - - if metrics != nil { - objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", - account.Id, objectCount) - } - } - - return nm -} - -func (b *NetworkMapBuilder) assembleNetworkMap( - ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, - dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{}, -) *NetworkMap { - - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - - for _, peerID := range aclView.ConnectedPeerIDs { - if _, ok := validatedPeers[peerID]; !ok { - continue - } - - peer := b.cache.globalPeers[peerID] - if peer == nil { - continue - } - - expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) - if account.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, peer) - } else { - peersToConnect = append(peersToConnect, peer) - } - } - - var routes []*route.Route - allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) - - for _, routeID := range allRouteIDs { - if route := b.cache.globalRoutes[routeID]; route != nil { - routes = append(routes, route) - } - } - - var firewallRules []*FirewallRule - for _, ruleID := range aclView.FirewallRuleIDs { - if rule := b.cache.globalRules[ruleID]; rule != nil { - firewallRules = append(firewallRules, rule) - } else { - log.Debugf("NetworkMapBuilder: peer %s assembling network map has no fwrule %s in globalRules", peer.ID, ruleID) - } - } - - var routesFirewallRules []*RouteFirewallRule - for _, ruleID := range routesView.RouteFirewallRuleIDs { - if rule := b.cache.globalRouteRules[ruleID]; rule != nil { - routesFirewallRules = append(routesFirewallRules, rule) - } - } - - finalDNSConfig := *dnsConfig - if finalDNSConfig.ServiceEnable { - var zones []nbdns.CustomZone - - peerGroupsSlice := b.cache.peerToGroups[peer.ID] - peerGroups := make(LookupMap, len(peerGroupsSlice)) - for _, groupID := range peerGroupsSlice { - peerGroups[groupID] = struct{}{} - } - - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers) - zones = append(zones, nbdns.CustomZone{ - Domain: peersCustomZone.Domain, - Records: records, - }) - } - - filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) - zones = append(zones, filteredAccountZones...) - - finalDNSConfig.CustomZones = zones - } - - nm := &NetworkMap{ - Peers: peersToConnect, - Network: account.Network.Copy(), - Routes: routes, - DNSConfig: finalDNSConfig, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: routesFirewallRules, - } - - if sshView != nil { - nm.EnableSSH = sshView.EnableSSH - nm.AuthorizedUsers = sshView.AuthorizedUsers - } - - return nm -} - -func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { - var s strings.Builder - s.WriteString(fw) - s.WriteString(rule.PolicyID) - s.WriteRune(':') - s.WriteString(rule.PeerIP) - s.WriteRune(':') - s.WriteString(strconv.Itoa(rule.Direction)) - s.WriteRune(':') - s.WriteString(rule.Protocol) - s.WriteRune(':') - s.WriteString(rule.Action) - s.WriteRune(':') - s.WriteString(rule.Port) - s.WriteRune(':') - s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) - s.WriteRune('-') - s.WriteString(strconv.Itoa(int(rule.PortRange.End))) - return s.String() -} - -func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { - var s strings.Builder - s.WriteString(rfw) - s.WriteString(string(rule.RouteID)) - s.WriteRune(':') - s.WriteString(rule.Destination) - s.WriteRune(':') - s.WriteString(rule.Action) - s.WriteRune(':') - s.WriteString(strings.Join(rule.SourceRanges, ",")) - s.WriteRune(':') - s.WriteString(rule.Protocol) - s.WriteRune(':') - s.WriteString(strconv.Itoa(int(rule.Port))) - return s.String() -} - -func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(peerGroups, groupID) { - return true - } - } - return false -} - -func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { - for _, r := range account.Routes { - if !r.Enabled { - continue - } - - if r.PeerID == peerID { - return true - } - - if peer := b.cache.globalPeers[peerID]; peer != nil { - if r.Peer == peer.Key && r.PeerID == "" { - return true - } - } - } - - routers := account.GetResourceRoutersMap() - for _, networkRouters := range routers { - if router, exists := networkRouters[peerID]; exists && router.Enabled { - return true - } - } - - return false -} - -func (b *NetworkMapBuilder) incAddPeerLoop() { - for { - b.apb.mu.Lock() - if len(b.apb.ids) == 0 { - b.apb.sg.Wait() - } - b.addPeersIncrementally() - b.apb.mu.Unlock() - } -} - -// lock on b.apb level should be held -func (b *NetworkMapBuilder) addPeersIncrementally() { - peers := slices.Clone(b.apb.ids) - clear(b.apb.ids) - b.apb.ids = b.apb.ids[:0] - latestAcc := b.apb.la - b.apb.mu.Unlock() - - tt := time.Now() - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(latestAcc) - - log.Debugf("NetworkMapBuilder: Starting incremental add of %d peers", len(peers)) - - allUpdates := make(map[string]*PeerUpdateDelta) - - for _, peerID := range peers { - peer := account.GetPeer(peerID) - if peer == nil { - b.apb.mu.Lock() - retries := b.apb.retryCount[peerID] - b.apb.mu.Unlock() - - if retries >= maxPeerAddRetries { - log.Errorf("NetworkMapBuilder: peer %s not found in account %s after %d retries, giving up", peerID, account.Id, retries) - b.apb.mu.Lock() - delete(b.apb.retryCount, peerID) - b.apb.mu.Unlock() - continue - } - - log.Warnf("NetworkMapBuilder: peer %s not found in account %s, retry %d/%d", peerID, account.Id, retries+1, maxPeerAddRetries) - b.apb.mu.Lock() - b.apb.retryCount[peerID] = retries + 1 - b.apb.mu.Unlock() - b.enqueuePeersForIncrementalAdd(latestAcc, peerID) - continue - } - - b.apb.mu.Lock() - delete(b.apb.retryCount, peerID) - b.apb.mu.Unlock() - - b.validatedPeers[peerID] = struct{}{} - b.cache.globalPeers[peerID] = peer - - peerGroups := b.updateIndexesForNewPeer(account, peerID) - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - - peerDeltas := b.collectDeltasForNewPeer(account, peerID, peerGroups) - for affectedPeerID, delta := range peerDeltas { - if existing, ok := allUpdates[affectedPeerID]; ok { - existing.mergeFrom(delta) - continue - } - allUpdates[affectedPeerID] = delta - } - } - - for affectedPeerID, delta := range allUpdates { - b.applyDeltaToPeer(account, affectedPeerID, delta) - } - - log.Debugf("NetworkMapBuilder: Added %d peers to cache, affected %d peers, took %s", len(peers), len(allUpdates), time.Since(tt)) - - b.apb.mu.Lock() - if len(b.apb.ids) > 0 { - b.apb.sg.Signal() - } -} - -func (b *NetworkMapBuilder) enqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { - b.apb.mu.Lock() - b.apb.ids = append(b.apb.ids, peerIDs...) - if b.apb.la != nil && acc.Network.CurrentSerial() > b.apb.la.Network.CurrentSerial() { - b.apb.la = acc - } - b.apb.sg.Signal() - b.apb.mu.Unlock() -} - -func (b *NetworkMapBuilder) EnqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { - b.enqueuePeersForIncrementalAdd(acc, peerIDs...) -} - -type ViewDelta struct { - AddedPeerIDs []string - RemovedPeerIDs []string - AddedRuleIDs []string - RemovedRuleIDs []string -} - -func (b *NetworkMapBuilder) OnPeerAddedIncremental(acc *Account, peerID string) error { - tt := time.Now() - peer := acc.GetPeer(peerID) - if peer == nil { - return fmt.Errorf("NetworkMapBuilder: peer %s not found in account", peerID) - } - - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(acc) - - log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) - - b.validatedPeers[peerID] = struct{}{} - - b.cache.globalPeers[peerID] = peer - - peerGroups := b.updateIndexesForNewPeer(account, peerID) - - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - - log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) - - b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) - - log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt)) - - return nil -} - -func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { - peerGroups := make([]string, 0) - - for groupID, group := range account.Groups { - if slices.Contains(group.Peers, peerID) { - if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { - b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) - } - peerGroups = append(peerGroups, groupID) - } - } - - b.cache.peerToGroups[peerID] = peerGroups - - for _, r := range account.Routes { - if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { - continue - } - for _, groupID := range r.PeerGroups { - if !slices.Contains(b.cache.groupToRoutes[groupID], r) { - b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) - } - } - if r.Peer != "" { - if peer, ok := b.cache.globalPeers[r.Peer]; ok { - if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { - b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) - } - } - } - b.cache.globalRoutes[r.ID] = r - } - - return peerGroups -} - -func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { - updates := b.collectDeltasForNewPeer(account, newPeerID, peerGroups) - for affectedPeerID, delta := range updates { - b.applyDeltaToPeer(account, affectedPeerID, delta) - } -} - -func (b *NetworkMapBuilder) collectDeltasForNewPeer(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { - updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) - - if b.isPeerRouter(account, newPeerID) { - affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) - for affectedPeerID := range affectedByRoutes { - if affectedPeerID == newPeerID { - continue - } - if _, exists := updates[affectedPeerID]; !exists { - updates[affectedPeerID] = &PeerUpdateDelta{ - PeerID: affectedPeerID, - RebuildRoutesView: true, - } - } else { - updates[affectedPeerID].RebuildRoutesView = true - } - } - } - - return updates -} - -func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { - affected := make(map[string]struct{}) - enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) - - for _, route := range enabledRoutes { - for _, distGroupID := range route.Groups { - if peers := b.cache.groupToPeers[distGroupID]; peers != nil { - for _, peerID := range peers { - if peerID != newRouterID { - affected[peerID] = struct{}{} - } - } - } - } - - for _, peerGroupID := range route.PeerGroups { - if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { - for _, peerID := range peers { - if peerID != newRouterID { - affected[peerID] = struct{}{} - } - } - } - } - } - - for _, route := range account.Routes { - if !route.Enabled { - continue - } - - routerInPeerGroups := false - for _, peerGroupID := range route.PeerGroups { - if slices.Contains(routerGroups, peerGroupID) { - routerInPeerGroups = true - break - } - } - - if routerInPeerGroups { - for _, distGroupID := range route.Groups { - if peers := b.cache.groupToPeers[distGroupID]; peers != nil { - for _, peerID := range peers { - affected[peerID] = struct{}{} - } - } - } - } - } - - return affected -} - -func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { - updates := make(map[string]*PeerUpdateDelta) - ctx := context.Background() - - groupAllLn := 0 - if allGroup, err := account.GetGroupAll(); err == nil { - groupAllLn = len(allGroup.Peers) - 1 - } - - newPeer := b.cache.globalPeers[newPeerID] - if newPeer == nil { - return updates - } - - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - var peerInSources, peerInDestinations bool - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == newPeerID { - peerInSources = true - } else { - peerInSources = b.isPeerInGroups(rule.Sources, peerGroups) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == newPeerID { - peerInDestinations = true - } else { - peerInDestinations = b.isPeerInGroups(rule.Destinations, peerGroups) - } - - if peerInSources { - if len(rule.Destinations) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionIN) - } - } - - if peerInDestinations { - if len(rule.Sources) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) - } - } - - if rule.Bidirectional { - if peerInSources { - if len(rule.Destinations) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) - } - } - if peerInDestinations { - if len(rule.Sources) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionIN) - } - } - } - } - } - - b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) - - b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) - - b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) - - return updates -} - -func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( - ctx context.Context, account *Account, newPeerID string, - updates map[string]*PeerUpdateDelta, -) { - resourceRouters := b.cache.resourceRouters - - for networkID, routers := range resourceRouters { - router, isRouter := routers[newPeerID] - if !isRouter || !router.Enabled { - continue - } - - for _, resource := range b.cache.globalResources { - if resource.NetworkID != networkID { - continue - } - - policies := b.cache.resourcePolicies[resource.ID] - if len(policies) == 0 { - continue - } - - peersWithAccess := make(map[string]struct{}) - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - sourceGroups := policy.SourceGroups() - for _, sourceGroup := range sourceGroups { - groupPeers := b.cache.groupToPeers[sourceGroup] - for _, peerID := range groupPeers { - if peerID == newPeerID { - continue - } - - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - peersWithAccess[peerID] = struct{}{} - } - } - } - } - - for peerID := range peersWithAccess { - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - } - updates[peerID] = delta - } - - if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - delta.RebuildRoutesView = true - } - } - } -} - -func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( - newPeerID string, newPeer *nbpeer.Peer, - peerGroups []string, updates map[string]*PeerUpdateDelta, -) { - processedPeerRoutes := make(map[string]map[route.ID]struct{}) - - for routeID, info := range b.cache.noACGRoutes { - if info.PeerID == newPeerID { - continue - } - - b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) - - if processedPeerRoutes[info.PeerID] == nil { - processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) - } - processedPeerRoutes[info.PeerID][routeID] = struct{}{} - } - - for _, acg := range peerGroups { - routeInfos := b.cache.acgToRoutes[acg] - if routeInfos == nil { - continue - } - - for routeID, info := range routeInfos { - if info.PeerID == newPeerID { - continue - } - - if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { - if _, processed := processedRoutes[routeID]; processed { - continue - } - } - - b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) - - if processedPeerRoutes[info.PeerID] == nil { - processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) - } - processedPeerRoutes[info.PeerID][routeID] = struct{}{} - } - } -} - -func (b *NetworkMapBuilder) addRouteFirewallUpdate( - updates map[string]*PeerUpdateDelta, peerID string, - routeID string, sourceIP string, -) { - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), - } - updates[peerID] = delta - } - - for _, existing := range delta.UpdateRouteFirewallRules { - if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { - return - } - } - - delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ - RuleID: routeID, - AddSourceIP: sourceIP, - }) -} - -func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( - ctx context.Context, account *Account, newPeerID string, - newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, -) { - for _, resource := range b.cache.globalResources { - resourcePolicies := b.cache.resourcePolicies - resourceRouters := b.cache.resourceRouters - - policies := resourcePolicies[resource.ID] - peerHasAccess := false - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - sourceGroups := policy.SourceGroups() - for _, sourceGroup := range sourceGroups { - if slices.Contains(peerGroups, sourceGroup) { - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { - peerHasAccess = true - break - } - } - } - - if peerHasAccess { - break - } - } - - if !peerHasAccess { - continue - } - - networkRouters := resourceRouters[resource.NetworkID] - for routerPeerID, router := range networkRouters { - if !router.Enabled || routerPeerID == newPeerID { - continue - } - - delta := updates[routerPeerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: routerPeerID, - } - updates[routerPeerID] = delta - } - - if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - delta.RebuildRoutesView = true - } - } -} - -type PeerUpdateDelta struct { - PeerID string - AddConnectedPeers []string - AddFirewallRules []*FirewallRuleDelta - AddRoutes []route.ID - UpdateRouteFirewallRules []*RouteFirewallRuleUpdate - UpdateDNS bool - RebuildRoutesView bool -} - -func (d *PeerUpdateDelta) mergeFrom(other *PeerUpdateDelta) { - for _, peerID := range other.AddConnectedPeers { - if !slices.Contains(d.AddConnectedPeers, peerID) { - d.AddConnectedPeers = append(d.AddConnectedPeers, peerID) - } - } - - existingRuleIDs := make(map[string]struct{}, len(d.AddFirewallRules)) - for _, rule := range d.AddFirewallRules { - existingRuleIDs[rule.RuleID] = struct{}{} - } - for _, rule := range other.AddFirewallRules { - if _, exists := existingRuleIDs[rule.RuleID]; !exists { - d.AddFirewallRules = append(d.AddFirewallRules, rule) - existingRuleIDs[rule.RuleID] = struct{}{} - } - } - - for _, routeID := range other.AddRoutes { - if !slices.Contains(d.AddRoutes, routeID) { - d.AddRoutes = append(d.AddRoutes, routeID) - } - } - - existingRouteUpdates := make(map[string]map[string]struct{}) - for _, update := range d.UpdateRouteFirewallRules { - if existingRouteUpdates[update.RuleID] == nil { - existingRouteUpdates[update.RuleID] = make(map[string]struct{}) - } - existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} - } - for _, update := range other.UpdateRouteFirewallRules { - if existingRouteUpdates[update.RuleID] == nil { - existingRouteUpdates[update.RuleID] = make(map[string]struct{}) - } - if _, exists := existingRouteUpdates[update.RuleID][update.AddSourceIP]; !exists { - d.UpdateRouteFirewallRules = append(d.UpdateRouteFirewallRules, update) - existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} - } - } - - if other.UpdateDNS { - d.UpdateDNS = true - } - if other.RebuildRoutesView { - d.RebuildRoutesView = true - } -} - -type FirewallRuleDelta struct { - Rule *FirewallRule - RuleID string - Direction int -} - -type RouteFirewallRuleUpdate struct { - RuleID string - AddSourceIP string -} - -func (b *NetworkMapBuilder) addUpdateForPeersInGroups( - updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, - rule *PolicyRule, direction int, allGroupLn int, -) { - for _, groupID := range groupIDs { - peers := b.cache.groupToPeers[groupID] - cnt := 0 - for _, peerID := range peers { - if peerID == newPeerID { - continue - } - if _, ok := b.validatedPeers[peerID]; !ok { - continue - } - cnt++ - } - all := false - if allGroupLn > 0 && cnt == allGroupLn { - all = true - } - newPeer := b.cache.globalPeers[newPeerID] - fr := &FirewallRule{ - PolicyID: rule.ID, - PeerIP: newPeer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - for _, peerID := range peers { - if peerID == newPeerID { - continue - } - if _, ok := b.validatedPeers[peerID]; !ok { - continue - } - targetPeer := b.cache.globalPeers[peerID] - if targetPeer == nil { - continue - } - - peerIPForRule := fr.PeerIP - if all { - peerIPForRule = allPeers - } - - b.addOrUpdateFirewallRuleInDelta(updates, peerID, newPeerID, rule, direction, fr, peerIPForRule, targetPeer) - } - } -} - -func (b *NetworkMapBuilder) addUpdateForDirectPeerResource( - updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, - rule *PolicyRule, direction int, -) { - if targetPeerID == newPeerID { - return - } - - if _, ok := b.validatedPeers[targetPeerID]; !ok { - return - } - - newPeer := b.cache.globalPeers[newPeerID] - if newPeer == nil { - return - } - - targetPeer := b.cache.globalPeers[targetPeerID] - if targetPeer == nil { - return - } - - fr := &FirewallRule{ - PolicyID: rule.ID, - PeerIP: newPeer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - - b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer) -} - -func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( - updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, - rule *PolicyRule, direction int, baseRule *FirewallRule, peerIP string, targetPeer *nbpeer.Peer, -) { - delta := updates[targetPeerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: targetPeerID, - AddConnectedPeers: []string{newPeerID}, - AddFirewallRules: make([]*FirewallRuleDelta, 0), - } - updates[targetPeerID] = delta - } else if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - baseRule.PeerIP = peerIP - - if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { - expandedRules := expandPortsAndRanges(*baseRule, rule, targetPeer) - for _, expandedRule := range expandedRules { - ruleID := b.generateFirewallRuleID(expandedRule) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: expandedRule, - RuleID: ruleID, - Direction: direction, - }) - } - } else { - ruleID := b.generateFirewallRuleID(baseRule) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: baseRule, - RuleID: ruleID, - Direction: direction, - }) - } -} - -func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { - if len(delta.AddConnectedPeers) > 0 || len(delta.AddFirewallRules) > 0 { - if aclView := b.cache.peerACLs[peerID]; aclView != nil { - for _, connectedPeerID := range delta.AddConnectedPeers { - if !slices.Contains(aclView.ConnectedPeerIDs, connectedPeerID) { - aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, connectedPeerID) - } - } - - for _, ruleDelta := range delta.AddFirewallRules { - b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule - - if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { - aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) - } - } - } - } - - if delta.RebuildRoutesView { - b.buildPeerRoutesView(account, peerID) - } else if len(delta.UpdateRouteFirewallRules) > 0 { - if routesView := b.cache.peerRoutes[peerID]; routesView != nil { - b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) - } - } - - if delta.UpdateDNS { - b.buildPeerDNSView(account, peerID) - } -} - -func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { - for _, update := range updates { - for _, ruleID := range routesView.RouteFirewallRuleIDs { - rule := b.cache.globalRouteRules[ruleID] - if rule == nil { - continue - } - - if string(rule.RouteID) == update.RuleID { - if hasWildcard := slices.Contains(rule.SourceRanges, allWildcard) || slices.Contains(rule.SourceRanges, v6AllWildcard); hasWildcard { - break - } - - sourceIP := update.AddSourceIP - - if strings.Contains(sourceIP, ":") { - sourceIP += "/128" // IPv6 - } else { - sourceIP += "/32" // IPv4 - } - - if !slices.Contains(rule.SourceRanges, sourceIP) { - rule.SourceRanges = append(rule.SourceRanges, sourceIP) - } - break - } - } - } -} - -func (b *NetworkMapBuilder) OnPeerDeleted(acc *Account, peerID string) error { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(acc) - - deletedPeer := b.cache.globalPeers[peerID] - if deletedPeer == nil { - return fmt.Errorf("peer %s not found in cache", peerID) - } - - deletedPeerKey := deletedPeer.Key - peerGroups := b.cache.peerToGroups[peerID] - peerIP := deletedPeer.IP.String() - - log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) - - delete(b.validatedPeers, peerID) - - routesToDelete := []route.ID{} - - for routeID, r := range account.Routes { - if r.Peer != deletedPeerKey && r.PeerID != peerID { - continue - } - if len(r.PeerGroups) == 0 { - routesToDelete = append(routesToDelete, routeID) - continue - } - newPeerAssigned := false - for _, groupID := range r.PeerGroups { - candidatePeerIDs := b.cache.groupToPeers[groupID] - for _, candidatePeerID := range candidatePeerIDs { - if candidatePeerID == peerID { - continue - } - if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { - r.Peer = candidatePeer.Key - r.PeerID = candidatePeerID - newPeerAssigned = true - break - } - } - if newPeerAssigned { - break - } - } - - if !newPeerAssigned { - routesToDelete = append(routesToDelete, routeID) - } - } - - for _, routeID := range routesToDelete { - delete(account.Routes, routeID) - } - - delete(b.cache.peerACLs, peerID) - delete(b.cache.peerRoutes, peerID) - delete(b.cache.peerDNS, peerID) - delete(b.cache.peerSSH, peerID) - - delete(b.cache.globalPeers, peerID) - - for acg, routeMap := range b.cache.acgToRoutes { - for routeID, info := range routeMap { - if info.PeerID == peerID { - delete(routeMap, routeID) - } - } - if len(routeMap) == 0 { - delete(b.cache.acgToRoutes, acg) - } - } - - for _, groupID := range peerGroups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { - return id == peerID - }) - } - } - delete(b.cache.peerToGroups, peerID) - - affectedPeers := make(map[string]struct{}) - - for _, r := range account.Routes { - for _, groupID := range r.Groups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - for _, p := range peers { - affectedPeers[p] = struct{}{} - } - } - } - - for _, groupID := range r.PeerGroups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - for _, p := range peers { - affectedPeers[p] = struct{}{} - } - } - } - } - - for affectedPeerID := range affectedPeers { - if affectedPeerID == peerID { - continue - } - b.buildPeerRoutesView(account, affectedPeerID) - } - - peersToRebuildACL := make(map[string]struct{}) - peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP, peerGroups, peersToRebuildACL) - for affectedPeerID, updates := range peerDeletionUpdates { - b.applyDeletionUpdates(affectedPeerID, updates) - } - - for affectedPeerID := range peersToRebuildACL { - b.buildPeerACLView(account, affectedPeerID) - } - - b.cleanupUnusedRules() - - log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) - - return nil -} - -func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( - deletedPeerID string, - peerIP string, - peerGroups []string, - peersToRebuildACL map[string]struct{}, -) map[string]*PeerDeletionUpdate { - - affected := make(map[string]*PeerDeletionUpdate) - - for peerID, aclView := range b.cache.peerACLs { - if peerID == deletedPeerID { - continue - } - - if slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { - peersToRebuildACL[peerID] = struct{}{} - if affected[peerID] == nil { - affected[peerID] = &PeerDeletionUpdate{ - RemovePeerID: deletedPeerID, - PeerIP: peerIP, - } - } - } - } - - affectedRouteOwners := make(map[string]struct{}) - - for _, groupID := range peerGroups { - if routeMap, ok := b.cache.acgToRoutes[groupID]; ok { - for _, info := range routeMap { - if info.PeerID != deletedPeerID { - affectedRouteOwners[info.PeerID] = struct{}{} - } - } - } - } - - for _, info := range b.cache.noACGRoutes { - if info.PeerID != deletedPeerID { - affectedRouteOwners[info.PeerID] = struct{}{} - } - } - - for ownerPeerID := range affectedRouteOwners { - if affected[ownerPeerID] == nil { - affected[ownerPeerID] = &PeerDeletionUpdate{ - RemovePeerID: deletedPeerID, - PeerIP: peerIP, - RemoveFromSourceRanges: true, - } - } else { - affected[ownerPeerID].RemoveFromSourceRanges = true - } - } - - return affected -} - -type PeerDeletionUpdate struct { - RemovePeerID string - RemoveFirewallRuleIDs []string - RemoveRouteIDs []route.ID - RemoveFromSourceRanges bool - PeerIP string -} - -func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { - if routesView := b.cache.peerRoutes[peerID]; routesView != nil { - if len(updates.RemoveRouteIDs) > 0 { - routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { - return slices.Contains(updates.RemoveRouteIDs, routeID) - }) - } - - if updates.RemoveFromSourceRanges { - b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) - } - } -} - -func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { - sourceIPv4 := peerIP + "/32" - sourceIPv6 := peerIP + "/128" - - rulesToRemove := []string{} - - for _, ruleID := range routesView.RouteFirewallRuleIDs { - if rule := b.cache.globalRouteRules[ruleID]; rule != nil { - rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { - return source == sourceIPv4 || source == sourceIPv6 || source == peerIP - }) - - if len(rule.SourceRanges) == 0 { - rulesToRemove = append(rulesToRemove, ruleID) - } - } - } - - if len(rulesToRemove) > 0 { - routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { - return slices.Contains(rulesToRemove, ruleID) - }) - } -} - -func (b *NetworkMapBuilder) cleanupUnusedRules() { - usedFirewallRules := make(map[string]struct{}) - usedRouteRules := make(map[string]struct{}) - usedRoutes := make(map[route.ID]struct{}) - - for _, aclView := range b.cache.peerACLs { - for _, ruleID := range aclView.FirewallRuleIDs { - usedFirewallRules[ruleID] = struct{}{} - } - } - - for _, routesView := range b.cache.peerRoutes { - for _, ruleID := range routesView.RouteFirewallRuleIDs { - usedRouteRules[ruleID] = struct{}{} - } - - for _, routeID := range routesView.OwnRouteIDs { - usedRoutes[routeID] = struct{}{} - } - for _, routeID := range routesView.NetworkResourceIDs { - usedRoutes[routeID] = struct{}{} - } - } - - for ruleID := range b.cache.globalRules { - if _, used := usedFirewallRules[ruleID]; !used { - delete(b.cache.globalRules, ruleID) - } - } - - for ruleID := range b.cache.globalRouteRules { - if _, used := usedRouteRules[ruleID]; !used { - delete(b.cache.globalRouteRules, ruleID) - } - } - - for routeID := range b.cache.globalRoutes { - if _, used := usedRoutes[routeID]; !used { - delete(b.cache.globalRoutes, routeID) - } - } -} - -func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - peerStored, ok := b.cache.globalPeers[peer.ID] - if !ok { - return - } - *peerStored = *peer -} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index d4e1a8816..d410aec8d 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy { return c } +func (p *Policy) Equal(other *Policy) bool { + if p == nil || other == nil { + return p == other + } + + if p.ID != other.ID || + p.AccountID != other.AccountID || + p.Name != other.Name || + p.Description != other.Description || + p.Enabled != other.Enabled { + return false + } + + if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) { + return false + } + + if len(p.Rules) != len(other.Rules) { + return false + } + + otherRules := make(map[string]*PolicyRule, len(other.Rules)) + for _, r := range other.Rules { + otherRules[r.ID] = r + } + for _, r := range p.Rules { + otherRule, ok := otherRules[r.ID] + if !ok { + return false + } + if !r.Equal(otherRule) { + return false + } + } + + return true +} + // EventMeta returns activity event meta related to this policy func (p *Policy) EventMeta() map[string]any { return map[string]any{"name": p.Name} diff --git a/management/server/types/policy_test.go b/management/server/types/policy_test.go new file mode 100644 index 000000000..b1d7aabc2 --- /dev/null +++ b/management/server/types/policy_test.go @@ -0,0 +1,193 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRules(t *testing.T) { + a := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + b := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRuleCount(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc3", "pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2", "pc3"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentPostureChecks(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentScalarFields(t *testing.T) { + base := Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Description: "desc", + Enabled: true, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Description = "changed" + assert.False(t, base.Equal(&other)) +} + +func TestPolicyEqual_NilCases(t *testing.T) { + var a *Policy + var b *Policy + assert.True(t, a.Equal(b)) + + a = &Policy{ID: "pol1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyEqual_RulesMismatchByID(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_FullScenario(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc2", "pc1"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g2", "g1"}, + Destinations: []string{"g4", "g3"}, + Ports: []string{"443", "80", "8080"}, + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + }, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc1", "pc2"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g1", "g2"}, + Destinations: []string{"g3", "g4"}, + Ports: []string{"80", "8080", "443"}, + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + }, + }, + } + assert.True(t, a.Equal(b)) +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index bb75dd555..52c494a6a 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -1,6 +1,8 @@ package types import ( + "slices" + "github.com/netbirdio/netbird/shared/management/proto" ) @@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule { } return rule } + +func (pm *PolicyRule) Equal(other *PolicyRule) bool { + if pm == nil || other == nil { + return pm == other + } + + if pm.ID != other.ID || + pm.PolicyID != other.PolicyID || + pm.Name != other.Name || + pm.Description != other.Description || + pm.Enabled != other.Enabled || + pm.Action != other.Action || + pm.Bidirectional != other.Bidirectional || + pm.Protocol != other.Protocol || + pm.SourceResource != other.SourceResource || + pm.DestinationResource != other.DestinationResource || + pm.AuthorizedUser != other.AuthorizedUser { + return false + } + + if !stringSlicesEqualUnordered(pm.Sources, other.Sources) { + return false + } + if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) { + return false + } + if !stringSlicesEqualUnordered(pm.Ports, other.Ports) { + return false + } + if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) { + return false + } + if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) { + return false + } + + return true +} + +func stringSlicesEqualUnordered(a, b []string) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + sorted1 := make([]string, len(a)) + sorted2 := make([]string, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.Sort(sorted1) + slices.Sort(sorted2) + return slices.Equal(sorted1, sorted2) +} + +func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + cmp := func(x, y RulePortRange) int { + if x.Start != y.Start { + if x.Start < y.Start { + return -1 + } + return 1 + } + if x.End != y.End { + if x.End < y.End { + return -1 + } + return 1 + } + return 0 + } + sorted1 := make([]RulePortRange, len(a)) + sorted2 := make([]RulePortRange, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.SortFunc(sorted1, cmp) + slices.SortFunc(sorted2, cmp) + return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool { + return x.Start == y.Start && x.End == y.End + }) +} + +func authorizedGroupsEqual(a, b map[string][]string) bool { + if len(a) != len(b) { + return false + } + for k, va := range a { + vb, ok := b[k] + if !ok { + return false + } + if !stringSlicesEqualUnordered(va, vb) { + return false + } + } + return true +} diff --git a/management/server/types/policyrule_test.go b/management/server/types/policyrule_test.go new file mode 100644 index 000000000..816e72abb --- /dev/null +++ b/management/server/types/policyrule_test.go @@ -0,0 +1,194 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80", "22"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"22", "443", "80"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPorts(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "22"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2", "g3"}, + Destinations: []string{"g4", "g5"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g3", "g1", "g2"}, + Destinations: []string{"g5", "g4"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentSources(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 443}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1", "u2", "u3"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u3", "u1", "u2"}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g2": {"u1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) { + base := PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Name: "test", + Description: "desc", + Enabled: true, + Action: PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Action = PolicyTrafficActionDrop + assert.False(t, base.Equal(&other)) + + other = base + other.Protocol = PolicyRuleProtocolUDP + assert.False(t, base.Equal(&other)) +} + +func TestPolicyRuleEqual_NilCases(t *testing.T) { + var a *PolicyRule + var b *PolicyRule + assert.True(t, a.Equal(b)) + + a = &PolicyRule{ID: "rule1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyRuleEqual_EmptySlices(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{}, + Sources: nil, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: nil, + Sources: []string{}, + } + assert.True(t, a.Equal(b)) +} + diff --git a/management/server/types/update_reason.go b/management/server/types/update_reason.go new file mode 100644 index 000000000..9d752da9a --- /dev/null +++ b/management/server/types/update_reason.go @@ -0,0 +1,37 @@ +package types + +// UpdateReason describes why an account peers update was triggered. +type UpdateReason struct { + Resource UpdateResource + Operation UpdateOperation +} + +// UpdateResource represents the kind of resource that triggered an account peers update. +type UpdateResource string + +const ( + UpdateResourceAccountSettings UpdateResource = "account_settings" + UpdateResourceDNSSettings UpdateResource = "dns_settings" + UpdateResourceGroup UpdateResource = "group" + UpdateResourceNameServerGroup UpdateResource = "nameserver_group" + UpdateResourcePolicy UpdateResource = "policy" + UpdateResourcePostureCheck UpdateResource = "posture_check" + UpdateResourceRoute UpdateResource = "route" + UpdateResourceUser UpdateResource = "user" + UpdateResourcePeer UpdateResource = "peer" + UpdateResourceNetwork UpdateResource = "network" + UpdateResourceNetworkResource UpdateResource = "network_resource" + UpdateResourceNetworkRouter UpdateResource = "network_router" + UpdateResourceService UpdateResource = "service" + UpdateResourceZone UpdateResource = "zone" + UpdateResourceZoneRecord UpdateResource = "zone_record" +) + +// UpdateOperation represents the kind of change that triggered the update. +type UpdateOperation string + +const ( + UpdateOperationCreate UpdateOperation = "create" + UpdateOperationUpdate UpdateOperation = "update" + UpdateOperationDelete UpdateOperation = "delete" +) diff --git a/management/server/user.go b/management/server/user.go index 327aec2d0..43e0a9821 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -395,8 +396,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } - if expiresIn < 1 || expiresIn > 365 { - return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") + if expiresIn < account.PATMinExpireDays || expiresIn > account.PATMaxExpireDays { + return nil, status.Errorf(status.InvalidArgument, "expiration has to be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays) } allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create) @@ -417,6 +418,10 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, err } + if targetUser.AccountID != accountID { + return nil, status.NewPermissionDeniedError() + } + // @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() @@ -457,6 +462,10 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string return err } + if targetUser.AccountID != accountID { + return status.NewPermissionDeniedError() + } + if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return status.NewAdminPermissionError() } @@ -496,6 +505,10 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, err } + if targetUser.AccountID != accountID { + return nil, status.NewPermissionDeniedError() + } + if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() } @@ -523,6 +536,10 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, err } + if targetUser.AccountID != accountID { + return nil, status.NewPermissionDeniedError() + } + if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() } @@ -659,7 +676,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate}) } return updatedUsersInfo, globalErr @@ -764,9 +781,15 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact updatedUser.Role = update.Role updatedUser.Blocked = update.Blocked updatedUser.AutoGroups = update.AutoGroups - // these two fields can't be set via API, only via direct call to the method + // these fields can't be set via API, only via direct call to the method updatedUser.Issued = update.Issued updatedUser.IntegrationReference = update.IntegrationReference + if update.Name != "" { + updatedUser.Name = update.Name + } + if update.Email != "" { + updatedUser.Email = update.Email + } var transferredOwnerRole bool result, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update) diff --git a/management/server/user_test.go b/management/server/user_test.go index 800d2406c..c77ea53d1 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -336,6 +336,104 @@ func TestUser_GetAllPATs(t *testing.T) { assert.Equal(t, 2, len(pats)) } +func TestUser_PAT_CrossAccountProtection(t *testing.T) { + const ( + accountAID = "accountA" + accountBID = "accountB" + userAID = "userA" + adminBID = "adminB" + serviceUserBID = "serviceUserB" + regularUserBID = "regularUserB" + tokenBID = "tokenB1" + hashedTokenB = "SoMeHaShEdToKeNB" + ) + + setupStore := func(t *testing.T) (*DefaultAccountManager, func()) { + t.Helper() + + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err, "creating store") + + accountA := newAccountWithId(context.Background(), accountAID, userAID, "", "", "", false) + require.NoError(t, s.SaveAccount(context.Background(), accountA)) + + accountB := newAccountWithId(context.Background(), accountBID, adminBID, "", "", "", false) + accountB.Users[serviceUserBID] = &types.User{ + Id: serviceUserBID, + AccountID: accountBID, + IsServiceUser: true, + ServiceUserName: "svcB", + Role: types.UserRoleAdmin, + PATs: map[string]*types.PersonalAccessToken{ + tokenBID: { + ID: tokenBID, + HashedToken: hashedTokenB, + }, + }, + } + accountB.Users[regularUserBID] = &types.User{ + Id: regularUserBID, + AccountID: accountBID, + Role: types.UserRoleUser, + } + require.NoError(t, s.SaveAccount(context.Background(), accountB)) + + pm := permissions.NewManager(s) + am := &DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: pm, + } + return am, cleanup + } + + t.Run("CreatePAT for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.CreatePAT(context.Background(), accountAID, userAID, serviceUserBID, "xss-token", 7) + require.Error(t, err, "cross-account CreatePAT must fail") + + _, err = am.CreatePAT(context.Background(), accountAID, userAID, regularUserBID, "xss-token", 7) + require.Error(t, err, "cross-account CreatePAT for regular user must fail") + + _, err = am.CreatePAT(context.Background(), accountBID, adminBID, serviceUserBID, "legit-token", 7) + require.NoError(t, err, "same-account CreatePAT should succeed") + }) + + t.Run("DeletePAT for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + err := am.DeletePAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID) + require.Error(t, err, "cross-account DeletePAT must fail") + }) + + t.Run("GetPAT for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.GetPAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID) + require.Error(t, err, "cross-account GetPAT must fail") + }) + + t.Run("GetAllPATs for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.GetAllPATs(context.Background(), accountAID, userAID, serviceUserBID) + require.Error(t, err, "cross-account GetAllPATs must fail") + }) + + t.Run("CreatePAT with forged accountID targeting foreign user is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.CreatePAT(context.Background(), accountAID, userAID, adminBID, "forged", 7) + require.Error(t, err, "forged accountID CreatePAT must fail") + }) +} + func TestUser_Copy(t *testing.T) { // this is an imaginary case which will never be in DB this way user := types.User{ @@ -1488,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1511,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b65..1b1664490 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -2,7 +2,12 @@ package cmd import ( "fmt" + "os" + "os/signal" + "path/filepath" "strconv" + "strings" + "syscall" "github.com/spf13/cobra" @@ -99,6 +104,27 @@ var debugStopCmd = &cobra.Command{ SilenceUsage: true, } +var debugCaptureCmd = &cobra.Command{ + Use: "capture [filter expression]", + Short: "Capture packets on a client's WireGuard interface", + Long: `Captures decrypted packets flowing through a client's WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Filter arguments after the account ID use BPF-like syntax. + +Examples: + netbird-proxy debug capture + netbird-proxy debug capture --duration 1m host 10.0.0.1 + netbird-proxy debug capture host 10.0.0.1 and tcp port 443 + netbird-proxy debug capture not port 22 + netbird-proxy debug capture -o capture.pcap + netbird-proxy debug capture --pcap | tcpdump -r - -n + netbird-proxy debug capture --pcap | tshark -r -`, + Args: cobra.MinimumNArgs(1), + RunE: runDebugCapture, + SilenceUsage: true, +} + func init() { debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address") debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format") @@ -110,6 +136,12 @@ func init() { debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)") + debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + debugCaptureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length (text mode)") + debugCaptureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (text mode)") + debugCaptureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") + debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugClientsCmd) debugCmd.AddCommand(debugStatusCmd) @@ -119,6 +151,7 @@ func init() { debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStopCmd) + debugCmd.AddCommand(debugCaptureCmd) rootCmd.AddCommand(debugCmd) } @@ -171,3 +204,84 @@ func runDebugStart(cmd *cobra.Command, args []string) error { func runDebugStop(cmd *cobra.Command, args []string) error { return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) } + +func runDebugCapture(cmd *cobra.Command, args []string) error { + duration, _ := cmd.Flags().GetDuration("duration") + forcePcap, _ := cmd.Flags().GetBool("pcap") + verbose, _ := cmd.Flags().GetBool("verbose") + ascii, _ := cmd.Flags().GetBool("ascii") + outPath, _ := cmd.Flags().GetString("output") + + // Default to text. Use pcap when --pcap is set or --output is given. + wantText := !forcePcap && outPath == "" + + var filterExpr string + if len(args) > 1 { + filterExpr = strings.Join(args[1:], " ") + } + + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + out, cleanup, err := captureOutputWriter(cmd, outPath) + if err != nil { + return err + } + defer cleanup() + + if wantText { + cmd.PrintErrln("Capturing packets... Press Ctrl+C to stop.") + } else { + cmd.PrintErrln("Capturing packets (pcap)... Press Ctrl+C to stop.") + } + + var durationStr string + if duration > 0 { + durationStr = duration.String() + } + + err = getDebugClient(cmd).Capture(ctx, debug.CaptureOptions{ + AccountID: args[0], + Duration: durationStr, + FilterExpr: filterExpr, + Text: wantText, + Verbose: verbose, + ASCII: ascii, + Output: out, + }) + if err != nil { + return err + } + + cmd.PrintErrln("\nCapture finished.") + return nil +} + +// captureOutputWriter returns the writer and cleanup function for capture output. +func captureOutputWriter(cmd *cobra.Command, outPath string) (out *os.File, cleanup func(), err error) { + if outPath != "" { + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() { + if err := f.Close(); err != nil { + cmd.PrintErrf("close output file: %v\n", err) + } + if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 { + if err := os.Rename(tmpPath, outPath); err != nil { + cmd.PrintErrf("rename output file: %v\n", err) + } else { + cmd.PrintErrf("Wrote %s\n", outPath) + } + } else { + os.Remove(tmpPath) + } + }, nil + } + + return os.Stdout, func() { + // no cleanup needed for stdout + }, nil +} diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 1c36ee334..ec8980ad9 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -35,7 +35,7 @@ var ( ) var ( - logLevel string + logLevel string debugLogs bool mgmtAddr string addr string @@ -64,6 +64,8 @@ var ( supportsCustomPorts bool requireSubdomain bool geoDataDir string + crowdsecAPIURL string + crowdsecAPIKey string ) var rootCmd = &cobra.Command{ @@ -106,6 +108,8 @@ func init() { rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)") rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)") rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)") + rootCmd.Flags().StringVar(&crowdsecAPIURL, "crowdsec-api-url", envStringOrDefault("NB_PROXY_CROWDSEC_API_URL", ""), "CrowdSec LAPI URL for IP reputation checks") + rootCmd.Flags().StringVar(&crowdsecAPIKey, "crowdsec-api-key", envStringOrDefault("NB_PROXY_CROWDSEC_API_KEY", ""), "CrowdSec bouncer API key") } // Execute runs the root command. @@ -187,6 +191,8 @@ func runServer(cmd *cobra.Command, args []string) error { MaxDialTimeout: maxDialTimeout, MaxSessionIdleTimeout: maxSessionIdleTimeout, GeoDataDir: geoDataDir, + CrowdSecAPIURL: crowdsecAPIURL, + CrowdSecAPIKey: crowdsecAPIKey, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) diff --git a/proxy/cmd/proxy/main.go b/proxy/cmd/proxy/main.go index 14e540a2e..16e7e8ac2 100644 --- a/proxy/cmd/proxy/main.go +++ b/proxy/cmd/proxy/main.go @@ -1,8 +1,13 @@ package main import ( + "net/http" + // nolint:gosec + _ "net/http/pprof" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/proxy/cmd/proxy/cmd" ) @@ -21,6 +26,9 @@ var ( ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() cmd.SetVersionInfo(Version, Commit, BuildDate, GoVersion) cmd.Execute() } diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 3ed3275b5..3283f61db 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -2,6 +2,7 @@ package accesslog import ( "context" + "maps" "net/netip" "sync" "sync/atomic" @@ -126,6 +127,7 @@ type logEntry struct { BytesUpload int64 BytesDownload int64 Protocol Protocol + Metadata map[string]string } // Protocol identifies the transport protocol of an access log entry. @@ -150,8 +152,10 @@ type L4Entry struct { BytesDownload int64 // DenyReason, when non-empty, indicates the connection was denied. // Values match the HTTP auth mechanism strings: "ip_restricted", - // "country_restricted", "geo_unavailable". + // "country_restricted", "geo_unavailable", "crowdsec_ban", etc. DenyReason string + // Metadata carries extra context about the connection (e.g. CrowdSec verdict). + Metadata map[string]string } // LogL4 sends an access log entry for a layer-4 connection (TCP or UDP). @@ -167,6 +171,7 @@ func (l *Logger) LogL4(entry L4Entry) { DurationMs: entry.DurationMs, BytesUpload: entry.BytesUpload, BytesDownload: entry.BytesDownload, + Metadata: maps.Clone(entry.Metadata), } if entry.DenyReason != "" { if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) { @@ -258,6 +263,7 @@ func (l *Logger) log(entry logEntry) { BytesUpload: entry.BytesUpload, BytesDownload: entry.BytesDownload, Protocol: string(entry.Protocol), + Metadata: entry.Metadata, }, }); err != nil { l.logger.WithFields(log.Fields{ diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 81c790b17..5a0684c19 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -82,6 +82,7 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { BytesUpload: bytesUpload, BytesDownload: bytesDownload, Protocol: ProtocolHTTP, + Metadata: capturedData.GetMetadata(), } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID()) diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 670cafb68..3b383f8b4 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -167,6 +167,20 @@ func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request return true } + if verdict.IsCrowdSec() { + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetMetadata("crowdsec_verdict", verdict.String()) + if config.IPRestrictions.IsObserveOnly(verdict) { + cd.SetMetadata("crowdsec_mode", "observe") + } + } + } + + if config.IPRestrictions.IsObserveOnly(verdict) { + mw.logger.Debugf("CrowdSec observe: would block %s for %s (%s)", clientIP, r.Host, verdict) + return true + } + reason := verdict.String() mw.blockIPRestriction(r, reason) http.Error(w, "Forbidden", http.StatusForbidden) @@ -358,6 +372,12 @@ func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Req cd.SetAuthMethod(attemptedMethod) } } + + if oidcURL, ok := methods[auth.MethodOIDC.String()]; ok && len(methods) == 1 && oidcURL != "" { + http.Redirect(w, r, oidcURL, http.StatusFound) + return + } + web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized) } @@ -413,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat http.SetCookie(w, &http.Cookie{ Name: auth.SessionCookieName, Value: token, + Path: "/", HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 6063f070e..2c93d7912 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite) } +func TestSetSessionCookieHasRootPath(t *testing.T) { + w := httptest.NewRecorder() + setSessionCookie(w, "test-token", time.Hour) + + cookies := w.Result().Cookies() + require.Len(t, cookies, 1) + assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths") +} + func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) @@ -669,7 +678,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", - restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)) + restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})) require.NoError(t, err) handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -705,7 +714,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", - restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil)) + restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}})) require.NoError(t, err) handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -746,7 +755,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", - restrict.ParseFilter(nil, nil, []string{"US"}, nil)) + restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}})) require.NoError(t, err) handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -761,6 +770,56 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) { assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny") } +func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + oidcURL := "https://idp.example.com/authorize?client_id=abc" + scheme := &stubScheme{ + method: auth.MethodOIDC, + authFn: func(_ *http.Request) (string, string, error) { + return "", oidcURL, nil + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Code, "should redirect directly to IdP") + assert.Equal(t, oidcURL, rec.Header().Get("Location")) +} + +func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + oidcScheme := &stubScheme{ + method: auth.MethodOIDC, + authFn: func(_ *http.Request) (string, string, error) { + return "", "https://idp.example.com/authorize", nil + }, + } + pinScheme := &stubScheme{ + method: auth.MethodPIN, + authFn: func(_ *http.Request) (string, string, error) { + return "", "pin", nil + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, "should show login page when multiple methods exist") +} + // mockAuthenticator is a minimal mock for the authenticator gRPC interface // used by the Header scheme. type mockAuthenticator struct { diff --git a/proxy/internal/crowdsec/bouncer.go b/proxy/internal/crowdsec/bouncer.go new file mode 100644 index 000000000..06a452520 --- /dev/null +++ b/proxy/internal/crowdsec/bouncer.go @@ -0,0 +1,251 @@ +// Package crowdsec provides a CrowdSec stream bouncer that maintains a local +// decision cache for IP reputation checks. +package crowdsec + +import ( + "context" + "errors" + "net/netip" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/models" + csbouncer "github.com/crowdsecurity/go-cs-bouncer" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/restrict" +) + +// Bouncer wraps a CrowdSec StreamBouncer, maintaining a local cache of +// active decisions for fast IP lookups. It implements restrict.CrowdSecChecker. +type Bouncer struct { + mu sync.RWMutex + ips map[netip.Addr]*restrict.CrowdSecDecision + prefixes map[netip.Prefix]*restrict.CrowdSecDecision + ready atomic.Bool + + apiURL string + apiKey string + tickerInterval time.Duration + logger *log.Entry + + // lifeMu protects cancel and done from concurrent Start/Stop calls. + lifeMu sync.Mutex + cancel context.CancelFunc + done chan struct{} +} + +// compile-time check +var _ restrict.CrowdSecChecker = (*Bouncer)(nil) + +// NewBouncer creates a bouncer but does not start the stream. +func NewBouncer(apiURL, apiKey string, logger *log.Entry) *Bouncer { + return &Bouncer{ + apiURL: apiURL, + apiKey: apiKey, + logger: logger, + ips: make(map[netip.Addr]*restrict.CrowdSecDecision), + prefixes: make(map[netip.Prefix]*restrict.CrowdSecDecision), + } +} + +// Start launches the background goroutine that streams decisions from the +// CrowdSec LAPI. The stream runs until Stop is called or ctx is cancelled. +func (b *Bouncer) Start(ctx context.Context) error { + interval := b.tickerInterval + if interval == 0 { + interval = 10 * time.Second + } + stream := &csbouncer.StreamBouncer{ + APIKey: b.apiKey, + APIUrl: b.apiURL, + TickerInterval: interval.String(), + UserAgent: "netbird-proxy/1.0", + Scopes: []string{"ip", "range"}, + RetryInitialConnect: true, + } + + b.logger.Infof("connecting to CrowdSec LAPI at %s", b.apiURL) + + if err := stream.Init(); err != nil { + return err + } + + // Reset state from any previous run. + b.mu.Lock() + b.ips = make(map[netip.Addr]*restrict.CrowdSecDecision) + b.prefixes = make(map[netip.Prefix]*restrict.CrowdSecDecision) + b.mu.Unlock() + b.ready.Store(false) + + ctx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + + b.lifeMu.Lock() + if b.cancel != nil { + b.lifeMu.Unlock() + cancel() + return errors.New("bouncer already started") + } + b.cancel = cancel + b.done = done + b.lifeMu.Unlock() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + if err := stream.Run(ctx); err != nil && ctx.Err() == nil { + b.logger.Errorf("CrowdSec stream ended: %v", err) + } + }() + + go func() { + defer wg.Done() + b.consumeStream(ctx, stream) + }() + + go func() { + wg.Wait() + close(done) + }() + + return nil +} + +// Stop cancels the stream and waits for all goroutines to finish. +func (b *Bouncer) Stop() { + b.lifeMu.Lock() + cancel := b.cancel + done := b.done + b.cancel = nil + b.lifeMu.Unlock() + + if cancel != nil { + cancel() + <-done + } +} + +// Ready returns true after the first batch of decisions has been processed. +func (b *Bouncer) Ready() bool { + return b.ready.Load() +} + +// CheckIP looks up addr in the local decision cache. Returns nil if no +// active decision exists for the address. +// +// Prefix lookups are O(1): instead of scanning all stored prefixes, we +// probe the map for every possible containing prefix of the address +// (at most 33 for IPv4, 129 for IPv6). +func (b *Bouncer) CheckIP(addr netip.Addr) *restrict.CrowdSecDecision { + addr = addr.Unmap() + + b.mu.RLock() + defer b.mu.RUnlock() + + if d, ok := b.ips[addr]; ok { + return d + } + + maxBits := 32 + if addr.Is6() { + maxBits = 128 + } + // Walk from most-specific to least-specific prefix so the narrowest + // matching decision wins when ranges overlap. + for bits := maxBits; bits >= 0; bits-- { + prefix := netip.PrefixFrom(addr, bits).Masked() + if d, ok := b.prefixes[prefix]; ok { + return d + } + } + + return nil +} + +func (b *Bouncer) consumeStream(ctx context.Context, stream *csbouncer.StreamBouncer) { + first := true + for { + select { + case <-ctx.Done(): + return + case resp, ok := <-stream.Stream: + if !ok { + return + } + b.mu.Lock() + b.applyDeleted(resp.Deleted) + b.applyNew(resp.New) + b.mu.Unlock() + + if first { + b.ready.Store(true) + b.logger.Info("CrowdSec bouncer synced initial decisions") + first = false + } + } + } +} + +func (b *Bouncer) applyDeleted(decisions []*models.Decision) { + for _, d := range decisions { + if d.Value == nil || d.Scope == nil { + continue + } + value := *d.Value + + if strings.ToLower(*d.Scope) == "range" || strings.Contains(value, "/") { + prefix, err := netip.ParsePrefix(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec range deletion %q: %v", value, err) + continue + } + prefix = normalizePrefix(prefix) + delete(b.prefixes, prefix) + } else { + addr, err := netip.ParseAddr(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec IP deletion %q: %v", value, err) + continue + } + delete(b.ips, addr.Unmap()) + } + } +} + +func (b *Bouncer) applyNew(decisions []*models.Decision) { + for _, d := range decisions { + if d.Value == nil || d.Type == nil || d.Scope == nil { + continue + } + dec := &restrict.CrowdSecDecision{Type: restrict.DecisionType(*d.Type)} + value := *d.Value + + if strings.ToLower(*d.Scope) == "range" || strings.Contains(value, "/") { + prefix, err := netip.ParsePrefix(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec range %q: %v", value, err) + continue + } + prefix = normalizePrefix(prefix) + b.prefixes[prefix] = dec + } else { + addr, err := netip.ParseAddr(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec IP %q: %v", value, err) + continue + } + b.ips[addr.Unmap()] = dec + } + } +} + +// normalizePrefix unmaps v4-mapped-v6 addresses and zeros host bits so +// the prefix is a valid map key that matches CheckIP's probe logic. +func normalizePrefix(p netip.Prefix) netip.Prefix { + return netip.PrefixFrom(p.Addr().Unmap(), p.Bits()).Masked() +} diff --git a/proxy/internal/crowdsec/bouncer_test.go b/proxy/internal/crowdsec/bouncer_test.go new file mode 100644 index 000000000..3bd8aa068 --- /dev/null +++ b/proxy/internal/crowdsec/bouncer_test.go @@ -0,0 +1,337 @@ +package crowdsec + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/netip" + "sync" + "testing" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/models" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/restrict" +) + +func TestBouncer_CheckIP_Empty(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("1.2.3.4"))) +} + +func TestBouncer_CheckIP_ExactMatch(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("10.0.0.1")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionBan, d.Type) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("10.0.0.2"))) +} + +func TestBouncer_CheckIP_PrefixMatch(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.prefixes[netip.MustParsePrefix("192.168.1.0/24")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("192.168.1.100")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionBan, d.Type) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("192.168.2.1"))) +} + +func TestBouncer_CheckIP_UnmapsV4InV6(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("::ffff:10.0.0.1")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionBan, d.Type) +} + +func TestBouncer_Ready(t *testing.T) { + b := newTestBouncer() + assert.False(t, b.Ready()) + + b.ready.Store(true) + assert.True(t, b.Ready()) +} + +func TestBouncer_CheckIP_ExactBeforePrefix(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionCaptcha} + b.prefixes[netip.MustParsePrefix("10.0.0.0/8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("10.0.0.1")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionCaptcha, d.Type) + + d2 := b.CheckIP(netip.MustParseAddr("10.0.0.2")) + require.NotNil(t, d2) + assert.Equal(t, restrict.DecisionBan, d2.Type) +} + +func TestBouncer_ApplyNew_IP(t *testing.T) { + b := newTestBouncer() + + b.applyNew(makeDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "ban", scenario: "test/brute"}, + decision{scope: "ip", value: "5.6.7.8", dtype: "captcha", scenario: "test/crawl"}, + )) + + require.Len(t, b.ips, 2) + assert.Equal(t, restrict.DecisionBan, b.ips[netip.MustParseAddr("1.2.3.4")].Type) + assert.Equal(t, restrict.DecisionCaptcha, b.ips[netip.MustParseAddr("5.6.7.8")].Type) +} + +func TestBouncer_ApplyNew_Range(t *testing.T) { + b := newTestBouncer() + + b.applyNew(makeDecisions( + decision{scope: "range", value: "10.0.0.0/8", dtype: "ban"}, + )) + + require.Len(t, b.prefixes, 1) + assert.NotNil(t, b.prefixes[netip.MustParsePrefix("10.0.0.0/8")]) +} + +func TestBouncer_ApplyDeleted_IP(t *testing.T) { + b := newTestBouncer() + b.ips[netip.MustParseAddr("1.2.3.4")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + b.ips[netip.MustParseAddr("5.6.7.8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + b.applyDeleted(makeDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "ban"}, + )) + + assert.Len(t, b.ips, 1) + assert.Nil(t, b.ips[netip.MustParseAddr("1.2.3.4")]) + assert.NotNil(t, b.ips[netip.MustParseAddr("5.6.7.8")]) +} + +func TestBouncer_ApplyDeleted_Range(t *testing.T) { + b := newTestBouncer() + b.prefixes[netip.MustParsePrefix("10.0.0.0/8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + b.prefixes[netip.MustParsePrefix("192.168.0.0/16")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + b.applyDeleted(makeDecisions( + decision{scope: "range", value: "10.0.0.0/8", dtype: "ban"}, + )) + + require.Len(t, b.prefixes, 1) + assert.NotNil(t, b.prefixes[netip.MustParsePrefix("192.168.0.0/16")]) +} + +func TestBouncer_ApplyNew_OverwritesExisting(t *testing.T) { + b := newTestBouncer() + b.ips[netip.MustParseAddr("1.2.3.4")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + b.applyNew(makeDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "captcha"}, + )) + + assert.Equal(t, restrict.DecisionCaptcha, b.ips[netip.MustParseAddr("1.2.3.4")].Type) +} + +func TestBouncer_ApplyNew_SkipsInvalid(t *testing.T) { + b := newTestBouncer() + + b.applyNew(makeDecisions( + decision{scope: "ip", value: "not-an-ip", dtype: "ban"}, + decision{scope: "range", value: "also-not-valid", dtype: "ban"}, + )) + + assert.Empty(t, b.ips) + assert.Empty(t, b.prefixes) +} + +// TestBouncer_StreamIntegration tests the full flow: fake LAPI → StreamBouncer → Bouncer cache → CheckIP. +func TestBouncer_StreamIntegration(t *testing.T) { + lapi := newFakeLAPI() + ts := httptest.NewServer(lapi) + defer ts.Close() + + // Seed the LAPI with initial decisions. + lapi.setDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "ban", scenario: "crowdsecurity/ssh-bf"}, + decision{scope: "range", value: "10.0.0.0/8", dtype: "ban", scenario: "crowdsecurity/http-probing"}, + decision{scope: "ip", value: "5.5.5.5", dtype: "captcha", scenario: "crowdsecurity/http-crawl"}, + ) + + b := NewBouncer(ts.URL, "test-key", log.NewEntry(log.StandardLogger())) + b.tickerInterval = 200 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + require.NoError(t, b.Start(ctx)) + defer b.Stop() + + // Wait for initial sync. + require.Eventually(t, b.Ready, 5*time.Second, 50*time.Millisecond, "bouncer should become ready") + + // Verify decisions are cached. + d := b.CheckIP(netip.MustParseAddr("1.2.3.4")) + require.NotNil(t, d, "1.2.3.4 should be banned") + assert.Equal(t, restrict.DecisionBan, d.Type) + + d2 := b.CheckIP(netip.MustParseAddr("10.1.2.3")) + require.NotNil(t, d2, "10.1.2.3 should match range ban") + assert.Equal(t, restrict.DecisionBan, d2.Type) + + d3 := b.CheckIP(netip.MustParseAddr("5.5.5.5")) + require.NotNil(t, d3, "5.5.5.5 should have captcha") + assert.Equal(t, restrict.DecisionCaptcha, d3.Type) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("9.9.9.9")), "unknown IP should be nil") + + // Simulate a delta update: delete one IP, add a new one. + lapi.setDelta( + []decision{{scope: "ip", value: "1.2.3.4", dtype: "ban"}}, + []decision{{scope: "ip", value: "2.3.4.5", dtype: "throttle", scenario: "crowdsecurity/http-flood"}}, + ) + + // Wait for the delta to be picked up. + require.Eventually(t, func() bool { + return b.CheckIP(netip.MustParseAddr("2.3.4.5")) != nil + }, 5*time.Second, 50*time.Millisecond, "new decision should appear") + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("1.2.3.4")), "deleted decision should be gone") + + d4 := b.CheckIP(netip.MustParseAddr("2.3.4.5")) + require.NotNil(t, d4) + assert.Equal(t, restrict.DecisionThrottle, d4.Type) + + // Range ban should still be active. + assert.NotNil(t, b.CheckIP(netip.MustParseAddr("10.99.99.99"))) +} + +// Helpers + +func newTestBouncer() *Bouncer { + return &Bouncer{ + ips: make(map[netip.Addr]*restrict.CrowdSecDecision), + prefixes: make(map[netip.Prefix]*restrict.CrowdSecDecision), + logger: log.NewEntry(log.StandardLogger()), + } +} + +type decision struct { + scope string + value string + dtype string + scenario string +} + +func makeDecisions(decs ...decision) []*models.Decision { + out := make([]*models.Decision, len(decs)) + for i, d := range decs { + out[i] = &models.Decision{ + Scope: strPtr(d.scope), + Value: strPtr(d.value), + Type: strPtr(d.dtype), + Scenario: strPtr(d.scenario), + Duration: strPtr("1h"), + Origin: strPtr("cscli"), + } + } + return out +} + +func strPtr(s string) *string { return &s } + +// fakeLAPI is a minimal fake CrowdSec LAPI that serves /v1/decisions/stream. +type fakeLAPI struct { + mu sync.Mutex + initial []decision + newDelta []decision + delDelta []decision + served bool // true after the initial snapshot has been served +} + +func newFakeLAPI() *fakeLAPI { + return &fakeLAPI{} +} + +func (f *fakeLAPI) setDecisions(decs ...decision) { + f.mu.Lock() + defer f.mu.Unlock() + f.initial = decs + f.served = false +} + +func (f *fakeLAPI) setDelta(deleted, added []decision) { + f.mu.Lock() + defer f.mu.Unlock() + f.delDelta = deleted + f.newDelta = added +} + +func (f *fakeLAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/decisions/stream" { + http.NotFound(w, r) + return + } + + f.mu.Lock() + defer f.mu.Unlock() + + resp := streamResponse{} + + if !f.served { + for _, d := range f.initial { + resp.New = append(resp.New, toLAPIDecision(d)) + } + f.served = true + } else { + for _, d := range f.delDelta { + resp.Deleted = append(resp.Deleted, toLAPIDecision(d)) + } + for _, d := range f.newDelta { + resp.New = append(resp.New, toLAPIDecision(d)) + } + // Clear delta after serving once. + f.delDelta = nil + f.newDelta = nil + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) //nolint:errcheck +} + +// streamResponse mirrors the CrowdSec LAPI /v1/decisions/stream JSON structure. +type streamResponse struct { + New []*lapiDecision `json:"new"` + Deleted []*lapiDecision `json:"deleted"` +} + +type lapiDecision struct { + Duration *string `json:"duration"` + Origin *string `json:"origin"` + Scenario *string `json:"scenario"` + Scope *string `json:"scope"` + Type *string `json:"type"` + Value *string `json:"value"` +} + +func toLAPIDecision(d decision) *lapiDecision { + return &lapiDecision{ + Duration: strPtr("1h"), + Origin: strPtr("cscli"), + Scenario: strPtr(d.scenario), + Scope: strPtr(d.scope), + Type: strPtr(d.dtype), + Value: strPtr(d.value), + } +} diff --git a/proxy/internal/crowdsec/registry.go b/proxy/internal/crowdsec/registry.go new file mode 100644 index 000000000..652fb6f9f --- /dev/null +++ b/proxy/internal/crowdsec/registry.go @@ -0,0 +1,103 @@ +package crowdsec + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +// Registry manages a single shared Bouncer instance with reference counting. +// The bouncer starts when the first service acquires it and stops when the +// last service releases it. +type Registry struct { + mu sync.Mutex + bouncer *Bouncer + refs map[types.ServiceID]struct{} + apiURL string + apiKey string + logger *log.Entry + cancel context.CancelFunc +} + +// NewRegistry creates a registry. The bouncer is not started until Acquire is called. +func NewRegistry(apiURL, apiKey string, logger *log.Entry) *Registry { + return &Registry{ + apiURL: apiURL, + apiKey: apiKey, + logger: logger, + refs: make(map[types.ServiceID]struct{}), + } +} + +// Available returns true when the LAPI URL and API key are configured. +func (r *Registry) Available() bool { + return r.apiURL != "" && r.apiKey != "" +} + +// Acquire registers svcID as a consumer and starts the bouncer if this is the +// first consumer. Returns the shared Bouncer (which implements the restrict +// package's CrowdSecChecker interface). Returns nil if not Available. +func (r *Registry) Acquire(svcID types.ServiceID) *Bouncer { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.Available() { + return nil + } + + if _, exists := r.refs[svcID]; exists { + return r.bouncer + } + + if r.bouncer == nil { + r.startLocked() + } + + // startLocked may fail, leaving r.bouncer nil. + if r.bouncer == nil { + return nil + } + + r.refs[svcID] = struct{}{} + return r.bouncer +} + +// Release removes svcID as a consumer. Stops the bouncer when the last +// consumer releases. +func (r *Registry) Release(svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.refs, svcID) + + if len(r.refs) == 0 && r.bouncer != nil { + r.stopLocked() + } +} + +func (r *Registry) startLocked() { + b := NewBouncer(r.apiURL, r.apiKey, r.logger) + + ctx, cancel := context.WithCancel(context.Background()) + r.cancel = cancel + + if err := b.Start(ctx); err != nil { + r.logger.Errorf("failed to start CrowdSec bouncer: %v", err) + cancel() + return + } + + r.bouncer = b + r.logger.Info("CrowdSec bouncer started") +} + +func (r *Registry) stopLocked() { + r.bouncer.Stop() + r.cancel() + r.bouncer = nil + r.cancel = nil + r.logger.Info("CrowdSec bouncer stopped") +} diff --git a/proxy/internal/crowdsec/registry_test.go b/proxy/internal/crowdsec/registry_test.go new file mode 100644 index 000000000..f1567b186 --- /dev/null +++ b/proxy/internal/crowdsec/registry_test.go @@ -0,0 +1,66 @@ +package crowdsec + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRegistry_Available(t *testing.T) { + r := NewRegistry("http://localhost:8080/", "test-key", log.NewEntry(log.StandardLogger())) + assert.True(t, r.Available()) + + r2 := NewRegistry("", "", log.NewEntry(log.StandardLogger())) + assert.False(t, r2.Available()) + + r3 := NewRegistry("http://localhost:8080/", "", log.NewEntry(log.StandardLogger())) + assert.False(t, r3.Available()) +} + +func TestRegistry_Acquire_NotAvailable(t *testing.T) { + r := NewRegistry("", "", log.NewEntry(log.StandardLogger())) + b := r.Acquire("svc-1") + assert.Nil(t, b) +} + +func TestRegistry_Acquire_Idempotent(t *testing.T) { + r := newTestRegistry() + + b1 := r.Acquire("svc-1") + // Can't start without a real LAPI, but we can verify the ref tracking. + // The bouncer will be nil because Start fails, but the ref is tracked. + _ = b1 + + assert.Len(t, r.refs, 1) + + // Second acquire of same service should not add another ref. + r.Acquire("svc-1") + assert.Len(t, r.refs, 1) +} + +func TestRegistry_Release_Removes(t *testing.T) { + r := newTestRegistry() + r.refs[types.ServiceID("svc-1")] = struct{}{} + + r.Release("svc-1") + assert.Empty(t, r.refs) +} + +func TestRegistry_Release_Noop(t *testing.T) { + r := newTestRegistry() + // Releasing a service that was never acquired should not panic. + r.Release("nonexistent") + assert.Empty(t, r.refs) +} + +func newTestRegistry() *Registry { + return &Registry{ + apiURL: "http://localhost:8080/", + apiKey: "test-key", + logger: log.NewEntry(log.StandardLogger()), + refs: make(map[types.ServiceID]struct{}), + } +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6..e01149522 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -310,6 +310,76 @@ func (c *Client) printError(data map[string]any) { } } +// CaptureOptions configures a capture request. +type CaptureOptions struct { + AccountID string + Duration string + FilterExpr string + Text bool + Verbose bool + ASCII bool + Output io.Writer +} + +// Capture streams a packet capture from the debug endpoint. The response body +// (pcap or text) is written directly to opts.Output until the server closes the +// connection or the context is cancelled. +func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error { + if opts.AccountID == "" { + return fmt.Errorf("account ID is required") + } + if opts.Output == nil { + return fmt.Errorf("output writer is required") + } + + params := url.Values{} + if opts.Duration != "" { + params.Set("duration", opts.Duration) + } + if opts.FilterExpr != "" { + params.Set("filter", opts.FilterExpr) + } + if opts.Text { + params.Set("format", "text") + } + if opts.Verbose { + params.Set("verbose", "true") + } + if opts.ASCII { + params.Set("ascii", "true") + } + + path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID)) + if len(params) > 0 { + path += "?" + params.Encode() + } + + fullURL := c.baseURL + path + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + // Use a separate client without timeout since captures stream for their full duration. + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + _, err = io.Copy(opts.Output, resp.Body) + if err != nil && ctx.Err() != nil { + return nil + } + return err +} + func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error { data, raw, err := c.fetch(ctx, path) if err != nil { diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9..6cd124554 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -174,6 +174,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat h.handleClientStart(w, r, accountID) case "stop": h.handleClientStop(w, r, accountID) + case "capture": + h.handleCapture(w, r, accountID) default: return false } @@ -632,6 +634,81 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou }) } +const maxCaptureDuration = 30 * time.Minute + +// handleCapture streams a pcap or text packet capture for the given client. +// +// Query params: +// +// duration: capture duration (0 or absent = max, capped at 30m) +// format: "text" for human-readable output (default: pcap) +// filter: BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443") +func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + http.Error(w, "client not found", http.StatusNotFound) + return + } + + duration := maxCaptureDuration + if durationStr := r.URL.Query().Get("duration"); durationStr != "" { + d, err := time.ParseDuration(durationStr) + if err != nil { + http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest) + return + } + if d < 0 { + http.Error(w, "duration must not be negative", http.StatusBadRequest) + return + } + if d > 0 { + duration = min(d, maxCaptureDuration) + } + } + + filter := r.URL.Query().Get("filter") + wantText := r.URL.Query().Get("format") == "text" + verbose := r.URL.Query().Get("verbose") == "true" + ascii := r.URL.Query().Get("ascii") == "true" + + opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii} + if wantText { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + opts.TextOutput = w + } else { + w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap") + w.Header().Set("Content-Disposition", + fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID)) + opts.Output = w + } + + cs, err := client.StartCapture(opts) + if err != nil { + http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable) + return + } + defer cs.Stop() + + // Flush headers after setup succeeds so errors above can still set status codes. + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-r.Context().Done(): + case <-timer.C: + } + + cs.Stop() + + stats := cs.Stats() + h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped", + accountID, stats.Packets, stats.Bytes, stats.Dropped) +} + func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) { if !wantJSON { http.Redirect(w, r, "/debug", http.StatusSeeOther) diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index d3f67dc57..a888ad9ed 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "maps" "net/netip" "sync" @@ -52,6 +53,7 @@ type CapturedData struct { clientIP netip.Addr userID string authMethod string + metadata map[string]string } // NewCapturedData creates a CapturedData with the given request ID. @@ -150,6 +152,23 @@ func (c *CapturedData) GetAuthMethod() string { return c.authMethod } +// SetMetadata sets a key-value pair in the metadata map. +func (c *CapturedData) SetMetadata(key, value string) { + c.mu.Lock() + defer c.mu.Unlock() + if c.metadata == nil { + c.metadata = make(map[string]string) + } + c.metadata[key] = value +} + +// GetMetadata returns a copy of the metadata map. +func (c *CapturedData) GetMetadata() map[string]string { + c.mu.RLock() + defer c.mu.RUnlock() + return maps.Clone(c.metadata) +} + // WithCapturedData adds a CapturedData struct to the context. func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { return context.WithValue(ctx, capturedDataKey, data) diff --git a/proxy/internal/restrict/restrict.go b/proxy/internal/restrict/restrict.go index a0d99ce93..f3e0fa695 100644 --- a/proxy/internal/restrict/restrict.go +++ b/proxy/internal/restrict/restrict.go @@ -12,12 +12,44 @@ import ( "github.com/netbirdio/netbird/proxy/internal/geolocation" ) +// defaultLogger is used when no logger is provided to ParseFilter. +var defaultLogger = log.NewEntry(log.StandardLogger()) + // GeoResolver resolves an IP address to geographic information. type GeoResolver interface { LookupAddr(addr netip.Addr) geolocation.Result Available() bool } +// DecisionType is the type of CrowdSec remediation action. +type DecisionType string + +const ( + DecisionBan DecisionType = "ban" + DecisionCaptcha DecisionType = "captcha" + DecisionThrottle DecisionType = "throttle" +) + +// CrowdSecDecision holds the type of a CrowdSec decision. +type CrowdSecDecision struct { + Type DecisionType +} + +// CrowdSecChecker queries CrowdSec decisions for an IP address. +type CrowdSecChecker interface { + CheckIP(addr netip.Addr) *CrowdSecDecision + Ready() bool +} + +// CrowdSecMode is the per-service enforcement mode. +type CrowdSecMode string + +const ( + CrowdSecOff CrowdSecMode = "" + CrowdSecEnforce CrowdSecMode = "enforce" + CrowdSecObserve CrowdSecMode = "observe" +) + // Filter evaluates IP restrictions. CIDR checks are performed first // (cheap), followed by country lookups (more expensive) only when needed. type Filter struct { @@ -25,32 +57,55 @@ type Filter struct { BlockedCIDRs []netip.Prefix AllowedCountries []string BlockedCountries []string + CrowdSec CrowdSecChecker + CrowdSecMode CrowdSecMode } -// ParseFilter builds a Filter from the raw string slices. Returns nil -// if all slices are empty. -func ParseFilter(allowedCIDRs, blockedCIDRs, allowedCountries, blockedCountries []string) *Filter { - if len(allowedCIDRs) == 0 && len(blockedCIDRs) == 0 && - len(allowedCountries) == 0 && len(blockedCountries) == 0 { +// FilterConfig holds the raw configuration for building a Filter. +type FilterConfig struct { + AllowedCIDRs []string + BlockedCIDRs []string + AllowedCountries []string + BlockedCountries []string + CrowdSec CrowdSecChecker + CrowdSecMode CrowdSecMode + Logger *log.Entry +} + +// ParseFilter builds a Filter from the config. Returns nil if no restrictions +// are configured. +func ParseFilter(cfg FilterConfig) *Filter { + hasCS := cfg.CrowdSecMode == CrowdSecEnforce || cfg.CrowdSecMode == CrowdSecObserve + if len(cfg.AllowedCIDRs) == 0 && len(cfg.BlockedCIDRs) == 0 && + len(cfg.AllowedCountries) == 0 && len(cfg.BlockedCountries) == 0 && !hasCS { return nil } - f := &Filter{ - AllowedCountries: normalizeCountryCodes(allowedCountries), - BlockedCountries: normalizeCountryCodes(blockedCountries), + logger := cfg.Logger + if logger == nil { + logger = defaultLogger } - for _, cidr := range allowedCIDRs { + + f := &Filter{ + AllowedCountries: normalizeCountryCodes(cfg.AllowedCountries), + BlockedCountries: normalizeCountryCodes(cfg.BlockedCountries), + } + if hasCS { + f.CrowdSec = cfg.CrowdSec + f.CrowdSecMode = cfg.CrowdSecMode + } + for _, cidr := range cfg.AllowedCIDRs { prefix, err := netip.ParsePrefix(cidr) if err != nil { - log.Warnf("skip invalid allowed CIDR %q: %v", cidr, err) + logger.Warnf("skip invalid allowed CIDR %q: %v", cidr, err) continue } f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked()) } - for _, cidr := range blockedCIDRs { + for _, cidr := range cfg.BlockedCIDRs { prefix, err := netip.ParsePrefix(cidr) if err != nil { - log.Warnf("skip invalid blocked CIDR %q: %v", cidr, err) + logger.Warnf("skip invalid blocked CIDR %q: %v", cidr, err) continue } f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked()) @@ -82,6 +137,15 @@ const ( // DenyGeoUnavailable indicates that country restrictions are configured // but the geo lookup is unavailable. DenyGeoUnavailable + // DenyCrowdSecBan indicates a CrowdSec "ban" decision. + DenyCrowdSecBan + // DenyCrowdSecCaptcha indicates a CrowdSec "captcha" decision. + DenyCrowdSecCaptcha + // DenyCrowdSecThrottle indicates a CrowdSec "throttle" decision. + DenyCrowdSecThrottle + // DenyCrowdSecUnavailable indicates enforce mode but the bouncer has not + // completed its initial sync. + DenyCrowdSecUnavailable ) // String returns the deny reason string matching the HTTP auth mechanism names. @@ -95,14 +159,42 @@ func (v Verdict) String() string { return "country_restricted" case DenyGeoUnavailable: return "geo_unavailable" + case DenyCrowdSecBan: + return "crowdsec_ban" + case DenyCrowdSecCaptcha: + return "crowdsec_captcha" + case DenyCrowdSecThrottle: + return "crowdsec_throttle" + case DenyCrowdSecUnavailable: + return "crowdsec_unavailable" default: return "unknown" } } +// IsCrowdSec returns true when the verdict originates from a CrowdSec check. +func (v Verdict) IsCrowdSec() bool { + switch v { + case DenyCrowdSecBan, DenyCrowdSecCaptcha, DenyCrowdSecThrottle, DenyCrowdSecUnavailable: + return true + default: + return false + } +} + +// IsObserveOnly returns true when v is a CrowdSec verdict and the filter is in +// observe mode. Callers should log the verdict but not block the request. +func (f *Filter) IsObserveOnly(v Verdict) bool { + if f == nil { + return false + } + return v.IsCrowdSec() && f.CrowdSecMode == CrowdSecObserve +} + // Check evaluates whether addr is permitted. CIDR rules are evaluated // first because they are O(n) prefix comparisons. Country rules run -// only when CIDR checks pass and require a geo lookup. +// only when CIDR checks pass and require a geo lookup. CrowdSec checks +// run last. func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict { if f == nil { return Allow @@ -115,7 +207,10 @@ func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict { if v := f.checkCIDR(addr); v != Allow { return v } - return f.checkCountry(addr, geo) + if v := f.checkCountry(addr, geo); v != Allow { + return v + } + return f.checkCrowdSec(addr) } func (f *Filter) checkCIDR(addr netip.Addr) Verdict { @@ -173,11 +268,48 @@ func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict { return Allow } +func (f *Filter) checkCrowdSec(addr netip.Addr) Verdict { + if f.CrowdSecMode == CrowdSecOff { + return Allow + } + + // Checker nil with enforce means CrowdSec was requested but the proxy + // has no LAPI configured. Fail-closed. + if f.CrowdSec == nil { + if f.CrowdSecMode == CrowdSecEnforce { + return DenyCrowdSecUnavailable + } + return Allow + } + + if !f.CrowdSec.Ready() { + if f.CrowdSecMode == CrowdSecEnforce { + return DenyCrowdSecUnavailable + } + return Allow + } + + d := f.CrowdSec.CheckIP(addr) + if d == nil { + return Allow + } + + switch d.Type { + case DecisionCaptcha: + return DenyCrowdSecCaptcha + case DecisionThrottle: + return DenyCrowdSecThrottle + default: + return DenyCrowdSecBan + } +} + // HasRestrictions returns true if any restriction rules are configured. func (f *Filter) HasRestrictions() bool { if f == nil { return false } return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 || - len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0 + len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0 || + f.CrowdSecMode == CrowdSecEnforce || f.CrowdSecMode == CrowdSecObserve } diff --git a/proxy/internal/restrict/restrict_test.go b/proxy/internal/restrict/restrict_test.go index 17a5848d8..abaa1afdc 100644 --- a/proxy/internal/restrict/restrict_test.go +++ b/proxy/internal/restrict/restrict_test.go @@ -29,21 +29,21 @@ func TestFilter_Check_NilFilter(t *testing.T) { } func TestFilter_Check_AllowedCIDR(t *testing.T) { - f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil)) } func TestFilter_Check_BlockedCIDR(t *testing.T) { - f := ParseFilter(nil, []string{"10.0.0.0/8"}, nil, nil) + f := ParseFilter(FilterConfig{BlockedCIDRs: []string{"10.0.0.0/8"}}) assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("192.168.1.1"), nil)) } func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) { - f := ParseFilter([]string{"10.0.0.0/8"}, []string{"10.1.0.0/16"}, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, BlockedCIDRs: []string{"10.1.0.0/16"}}) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist") assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist") @@ -56,7 +56,7 @@ func TestFilter_Check_AllowedCountry(t *testing.T) { "2.2.2.2": "DE", "3.3.3.3": "CN", }) - f := ParseFilter(nil, nil, []string{"US", "DE"}, nil) + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US", "DE"}}) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist") assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist") @@ -69,7 +69,7 @@ func TestFilter_Check_BlockedCountry(t *testing.T) { "2.2.2.2": "RU", "3.3.3.3": "US", }) - f := ParseFilter(nil, nil, nil, []string{"CN", "RU"}) + f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN", "RU"}}) assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist") assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist") @@ -83,7 +83,7 @@ func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) { "3.3.3.3": "CN", }) // Allow US and DE, but block DE explicitly. - f := ParseFilter(nil, nil, []string{"US", "DE"}, []string{"DE"}) + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US", "DE"}, BlockedCountries: []string{"DE"}}) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked") assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins") @@ -94,7 +94,7 @@ func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) { geo := newMockGeo(map[string]string{ "1.1.1.1": "US", }) - f := ParseFilter(nil, nil, []string{"US"}, nil) + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist") assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active") @@ -104,34 +104,34 @@ func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) { geo := newMockGeo(map[string]string{ "1.1.1.1": "CN", }) - f := ParseFilter(nil, nil, nil, []string{"CN"}) + f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}}) assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist") assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active") } func TestFilter_Check_CountryWithoutGeo(t *testing.T) { - f := ParseFilter(nil, nil, []string{"US"}, nil) + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}) assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist") } func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) { - f := ParseFilter(nil, nil, nil, []string{"CN"}) + f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}}) assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist") } func TestFilter_Check_GeoUnavailable(t *testing.T) { geo := &unavailableGeo{} - f := ParseFilter(nil, nil, []string{"US"}, nil) + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}) assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist") - f2 := ParseFilter(nil, nil, nil, []string{"CN"}) + f2 := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}}) assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist") } func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) { - f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) // CIDR-only filter should never touch geo, so nil geo is fine. assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) @@ -143,7 +143,7 @@ func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) { "10.1.2.3": "CN", "10.2.3.4": "US", }) - f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, []string{"CN"}) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, BlockedCountries: []string{"CN"}}) assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked") assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked") @@ -151,12 +151,12 @@ func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) { } func TestParseFilter_Empty(t *testing.T) { - f := ParseFilter(nil, nil, nil, nil) + f := ParseFilter(FilterConfig{}) assert.Nil(t, f) } func TestParseFilter_InvalidCIDR(t *testing.T) { - f := ParseFilter([]string{"invalid", "10.0.0.0/8"}, nil, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"invalid", "10.0.0.0/8"}}) assert.NotNil(t, f) assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped") @@ -166,12 +166,12 @@ func TestParseFilter_InvalidCIDR(t *testing.T) { func TestFilter_HasRestrictions(t *testing.T) { assert.False(t, (*Filter)(nil).HasRestrictions()) assert.False(t, (&Filter{}).HasRestrictions()) - assert.True(t, ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil).HasRestrictions()) - assert.True(t, ParseFilter(nil, nil, []string{"US"}, nil).HasRestrictions()) + assert.True(t, ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}).HasRestrictions()) + assert.True(t, ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}).HasRestrictions()) } func TestFilter_Check_IPv6CIDR(t *testing.T) { - f := ParseFilter([]string{"2001:db8::/32"}, nil, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"2001:db8::/32"}}) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist") assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist") @@ -179,7 +179,7 @@ func TestFilter_Check_IPv6CIDR(t *testing.T) { } func TestFilter_Check_IPv4MappedIPv6(t *testing.T) { - f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) // A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR. v4mapped := netip.MustParseAddr("::ffff:10.1.2.3") @@ -191,7 +191,7 @@ func TestFilter_Check_IPv4MappedIPv6(t *testing.T) { } func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) { - f := ParseFilter([]string{"10.0.0.0/8", "2001:db8::/32"}, nil, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8", "2001:db8::/32"}}) assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR") assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR") @@ -202,7 +202,7 @@ func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) { func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) { // 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24. - f := ParseFilter([]string{"1.1.1.1/24"}, nil, nil, nil) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"1.1.1.1/24"}}) assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0]) // Verify it still matches correctly. @@ -264,7 +264,7 @@ func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - f := ParseFilter(nil, nil, tc.allowedCountries, tc.blockedCountries) + f := ParseFilter(FilterConfig{AllowedCountries: tc.allowedCountries, BlockedCountries: tc.blockedCountries}) got := f.Check(netip.MustParseAddr(tc.addr), geo) assert.Equal(t, tc.want, got) }) @@ -275,4 +275,252 @@ func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) { type unavailableGeo struct{} func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} } -func (u *unavailableGeo) Available() bool { return false } +func (u *unavailableGeo) Available() bool { return false } + +// mockCrowdSec is a test implementation of CrowdSecChecker. +type mockCrowdSec struct { + decisions map[string]*CrowdSecDecision + ready bool +} + +func (m *mockCrowdSec) CheckIP(addr netip.Addr) *CrowdSecDecision { + return m.decisions[addr.Unmap().String()] +} + +func (m *mockCrowdSec) Ready() bool { return m.ready } + +func TestFilter_CrowdSec_Enforce_Ban(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecBan, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("5.6.7.8"), nil)) +} + +func TestFilter_CrowdSec_Enforce_Captcha(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionCaptcha}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecCaptcha, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Enforce_Throttle(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionThrottle}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecThrottle, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Observe_DoesNotBlock(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecObserve}) + + verdict := f.Check(netip.MustParseAddr("1.2.3.4"), nil) + assert.Equal(t, DenyCrowdSecBan, verdict, "verdict should be ban") + assert.True(t, f.IsObserveOnly(verdict), "should be observe-only") +} + +func TestFilter_CrowdSec_Enforce_NotReady(t *testing.T) { + cs := &mockCrowdSec{ready: false} + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Observe_NotReady_Allows(t *testing.T) { + cs := &mockCrowdSec{ready: false} + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecObserve}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Off(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecOff}) + + // CrowdSecOff means the filter is nil (no restrictions). + assert.Nil(t, f) +} + +func TestFilter_IsObserveOnly(t *testing.T) { + f := &Filter{CrowdSecMode: CrowdSecObserve} + assert.True(t, f.IsObserveOnly(DenyCrowdSecBan)) + assert.True(t, f.IsObserveOnly(DenyCrowdSecCaptcha)) + assert.True(t, f.IsObserveOnly(DenyCrowdSecThrottle)) + assert.True(t, f.IsObserveOnly(DenyCrowdSecUnavailable)) + assert.False(t, f.IsObserveOnly(DenyCIDR)) + assert.False(t, f.IsObserveOnly(Allow)) + + f2 := &Filter{CrowdSecMode: CrowdSecEnforce} + assert.False(t, f2.IsObserveOnly(DenyCrowdSecBan)) +} + +// TestFilter_LayerInteraction exercises the evaluation order across all three +// restriction layers: CIDR -> Country -> CrowdSec. Each layer can only further +// restrict; no layer can relax a denial from an earlier layer. +// +// Layer order | Behavior +// ---------------|------------------------------------------------------- +// 1. CIDR | Allowlist narrows to specific ranges, blocklist removes +// | specific ranges. Deny here → stop, CrowdSec never runs. +// 2. Country | Allowlist/blocklist by geo. Deny here → stop. +// 3. CrowdSec | IP reputation. Can block IPs that passed layers 1-2. +// | Observe mode: verdict returned but caller doesn't block. +func TestFilter_LayerInteraction(t *testing.T) { + bannedIP := "10.1.2.3" + cleanIP := "10.2.3.4" + outsideIP := "192.168.1.1" + + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{bannedIP: {Type: DecisionBan}}, + ready: true, + } + geo := newMockGeo(map[string]string{ + bannedIP: "US", + cleanIP: "US", + outsideIP: "CN", + }) + + tests := []struct { + name string + config FilterConfig + addr string + want Verdict + }{ + // CIDR allowlist + CrowdSec enforce: CrowdSec blocks inside allowed range + { + name: "allowed CIDR + CrowdSec banned", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: bannedIP, + want: DenyCrowdSecBan, + }, + { + name: "allowed CIDR + CrowdSec clean", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: cleanIP, + want: Allow, + }, + { + name: "CIDR deny stops before CrowdSec", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: outsideIP, + want: DenyCIDR, + }, + + // CIDR blocklist + CrowdSec enforce: blocklist blocks first, CrowdSec blocks remaining + { + name: "blocked CIDR stops before CrowdSec", + config: FilterConfig{BlockedCIDRs: []string{"10.1.0.0/16"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: bannedIP, + want: DenyCIDR, + }, + { + name: "not in blocklist + CrowdSec clean", + config: FilterConfig{BlockedCIDRs: []string{"10.1.0.0/16"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: cleanIP, + want: Allow, + }, + + // Country allowlist + CrowdSec enforce + { + name: "allowed country + CrowdSec banned", + config: FilterConfig{AllowedCountries: []string{"US"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: bannedIP, + want: DenyCrowdSecBan, + }, + { + name: "country deny stops before CrowdSec", + config: FilterConfig{AllowedCountries: []string{"US"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: outsideIP, + want: DenyCountry, + }, + + // All three layers: CIDR allowlist + country blocklist + CrowdSec + { + name: "all layers: CIDR allow + country allow + CrowdSec ban", + config: FilterConfig{ + AllowedCIDRs: []string{"10.0.0.0/8"}, + BlockedCountries: []string{"CN"}, + CrowdSec: cs, + CrowdSecMode: CrowdSecEnforce, + }, + addr: bannedIP, // 10.x (CIDR ok), US (country ok), banned (CrowdSec deny) + want: DenyCrowdSecBan, + }, + { + name: "all layers: CIDR deny short-circuits everything", + config: FilterConfig{ + AllowedCIDRs: []string{"10.0.0.0/8"}, + BlockedCountries: []string{"CN"}, + CrowdSec: cs, + CrowdSecMode: CrowdSecEnforce, + }, + addr: outsideIP, // 192.x (CIDR deny) + want: DenyCIDR, + }, + + // Observe mode: verdict returned but IsObserveOnly is true + { + name: "observe mode: CrowdSec banned inside allowed CIDR", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecObserve}, + addr: bannedIP, + want: DenyCrowdSecBan, // verdict is ban, caller checks IsObserveOnly + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + f := ParseFilter(tc.config) + got := f.Check(netip.MustParseAddr(tc.addr), geo) + assert.Equal(t, tc.want, got) + + // Verify observe mode flag when applicable. + if tc.config.CrowdSecMode == CrowdSecObserve && got.IsCrowdSec() { + assert.True(t, f.IsObserveOnly(got), "observe mode verdict should be observe-only") + } + if tc.config.CrowdSecMode == CrowdSecEnforce && got.IsCrowdSec() { + assert.False(t, f.IsObserveOnly(got), "enforce mode verdict should not be observe-only") + } + }) + } +} + +func TestFilter_CrowdSec_Enforce_NilChecker(t *testing.T) { + // LAPI not configured: checker is nil but mode is enforce. Must fail closed. + f := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Observe_NilChecker(t *testing.T) { + // LAPI not configured: checker is nil but mode is observe. Must allow. + f := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecObserve}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_HasRestrictions_CrowdSec(t *testing.T) { + cs := &mockCrowdSec{ready: true} + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + assert.True(t, f.HasRestrictions()) + + // Enforce mode without checker (LAPI not configured): still has restrictions + // because Check() will fail-closed with DenyCrowdSecUnavailable. + f2 := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecEnforce}) + assert.True(t, f2.HasRestrictions()) +} diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go index 8255c36d3..9f8660aeb 100644 --- a/proxy/internal/tcp/router.go +++ b/proxy/internal/tcp/router.go @@ -479,9 +479,14 @@ func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict // On success (nil error), both conn and backend are closed by the relay. func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error { if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow { - r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict) - r.logL4Deny(route, conn, verdict) - return errAccessRestricted + if route.Filter != nil && route.Filter.IsObserveOnly(verdict) { + r.logger.Debugf("CrowdSec observe: would block %s for %s (%s)", conn.RemoteAddr(), sni, verdict) + r.logL4Deny(route, conn, verdict, true) + } else { + r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict) + r.logL4Deny(route, conn, verdict, false) + return errAccessRestricted + } } svcCtx, err := r.acquireRelay(ctx, route) @@ -610,7 +615,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, } // logL4Deny sends an access log entry for a denied connection. -func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) { +func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict, observeOnly bool) { r.mu.RLock() al := r.accessLog r.mu.RUnlock() @@ -621,14 +626,22 @@ func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) sourceIP, _ := addrFromConn(conn) - al.LogL4(accesslog.L4Entry{ + entry := accesslog.L4Entry{ AccountID: route.AccountID, ServiceID: route.ServiceID, Protocol: route.Protocol, Host: route.Domain, SourceIP: sourceIP, DenyReason: verdict.String(), - }) + } + if verdict.IsCrowdSec() { + entry.Metadata = map[string]string{"crowdsec_verdict": verdict.String()} + if observeOnly { + entry.Metadata["crowdsec_mode"] = "observe" + entry.DenyReason = "" + } + } + al.LogL4(entry) } // getOrCreateServiceCtxLocked returns the context for a service, creating one diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go index 189cdc622..93b6560f4 100644 --- a/proxy/internal/tcp/router_test.go +++ b/proxy/internal/tcp/router_test.go @@ -1686,7 +1686,7 @@ func (f *fakeConn) RemoteAddr() net.Addr { return f.remote } func TestCheckRestrictions_UnparseableAddress(t *testing.T) { router := NewPortRouter(log.StandardLogger(), nil) - filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) route := Route{Filter: filter} conn := &fakeConn{remote: fakeAddr("not-an-ip")} @@ -1695,7 +1695,7 @@ func TestCheckRestrictions_UnparseableAddress(t *testing.T) { func TestCheckRestrictions_NilRemoteAddr(t *testing.T) { router := NewPortRouter(log.StandardLogger(), nil) - filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) route := Route{Filter: filter} conn := &fakeConn{remote: nil} @@ -1704,7 +1704,7 @@ func TestCheckRestrictions_NilRemoteAddr(t *testing.T) { func TestCheckRestrictions_AllowedAndDenied(t *testing.T) { router := NewPortRouter(log.StandardLogger(), nil) - filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) route := Route{Filter: filter} allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}} @@ -1724,7 +1724,7 @@ func TestCheckRestrictions_NilFilter(t *testing.T) { func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) { router := NewPortRouter(log.StandardLogger(), nil) - filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) route := Route{Filter: filter} // net.IPv4() returns a 16-byte v4-in-v6 representation internally. diff --git a/proxy/internal/udp/relay.go b/proxy/internal/udp/relay.go index d20ecf48b..8293bfe81 100644 --- a/proxy/internal/udp/relay.go +++ b/proxy/internal/udp/relay.go @@ -336,8 +336,13 @@ func (r *Relay) checkAccessRestrictions(addr net.Addr) error { return fmt.Errorf("parse client address %s for restriction check: %w", addr, err) } if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow { - r.logDeny(clientIP, v) - return fmt.Errorf("access restricted for %s", addr) + if r.filter.IsObserveOnly(v) { + r.logger.Debugf("CrowdSec observe: would block %s (%s)", clientIP, v) + r.logDeny(clientIP, v, true) + } else { + r.logDeny(clientIP, v, false) + return fmt.Errorf("access restricted for %s", addr) + } } return nil } @@ -498,19 +503,27 @@ func (r *Relay) logSessionEnd(sess *session) { } // logDeny sends an access log entry for a denied UDP packet. -func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) { +func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict, observeOnly bool) { if r.accessLog == nil { return } - r.accessLog.LogL4(accesslog.L4Entry{ + entry := accesslog.L4Entry{ AccountID: r.accountID, ServiceID: r.serviceID, Protocol: accesslog.ProtocolUDP, Host: r.domain, SourceIP: clientIP, DenyReason: verdict.String(), - }) + } + if verdict.IsCrowdSec() { + entry.Metadata = map[string]string{"crowdsec_verdict": verdict.String()} + if observeOnly { + entry.Metadata["crowdsec_mode"] = "observe" + entry.DenyReason = "" + } + } + r.accessLog.LogL4(entry) } // Close stops the relay, waits for all session goroutines to exit, diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index c30234b5a..4b1ecf922 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -22,6 +22,7 @@ import ( nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -113,11 +114,11 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { } // Create real token store - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) // Create real users manager usersManager := users.NewManager(testStore) @@ -200,7 +201,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { return nil } @@ -220,6 +221,18 @@ func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Clust return nil, nil } +func (m *testProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *testProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } @@ -247,14 +260,6 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string { return nil } -func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { - return nil -} - -func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool { - return nil -} - // storeBackedServiceManager reads directly from the real store. type storeBackedServiceManager struct { store store.Store diff --git a/proxy/server.go b/proxy/server.go index acfe3c12d..fbd0d058e 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -42,6 +42,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/certwatch" "github.com/netbirdio/netbird/proxy/internal/conntrack" + "github.com/netbirdio/netbird/proxy/internal/crowdsec" "github.com/netbirdio/netbird/proxy/internal/debug" "github.com/netbirdio/netbird/proxy/internal/geolocation" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" @@ -100,6 +101,13 @@ type Server struct { geo restrict.GeoResolver geoRaw *geolocation.Lookup + // crowdsecRegistry manages the shared CrowdSec bouncer lifecycle. + crowdsecRegistry *crowdsec.Registry + // crowdsecServices tracks which services have CrowdSec enabled for + // proper acquire/release lifecycle management. + crowdsecMu sync.Mutex + crowdsecServices map[types.ServiceID]bool + // routerReady is closed once mainRouter is fully initialized. // The mapping worker waits on this before processing updates. routerReady chan struct{} @@ -175,6 +183,10 @@ type Server struct { // GeoDataDir is the directory containing GeoLite2 MMDB files for // country-based access restrictions. Empty disables geo lookups. GeoDataDir string + // CrowdSecAPIURL is the CrowdSec LAPI URL. Empty disables CrowdSec. + CrowdSecAPIURL string + // CrowdSecAPIKey is the CrowdSec bouncer API key. Empty disables CrowdSec. + CrowdSecAPIKey string // MaxSessionIdleTimeout caps the per-service session idle timeout. // Zero means no cap (the proxy honors whatever management sends). // Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments. @@ -275,6 +287,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // management connectivity from the first stream connection. s.healthChecker = health.NewChecker(s.Logger, s.netbird) + s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger)) + s.crowdsecServices = make(map[types.ServiceID]bool) + go s.newManagementMappingWorker(runCtx, s.mgmtClient) tlsConfig, err := s.configureTLS(ctx) @@ -763,6 +778,22 @@ func (s *Server) shutdownServices() { s.Logger.Debugf("close geolocation: %v", err) } } + + s.shutdownCrowdSec() +} + +func (s *Server) shutdownCrowdSec() { + if s.crowdsecRegistry == nil { + return + } + s.crowdsecMu.Lock() + services := maps.Clone(s.crowdsecServices) + maps.Clear(s.crowdsecServices) + s.crowdsecMu.Unlock() + + for svcID := range services { + s.crowdsecRegistry.Release(svcID) + } } // resolveDialFunc returns a DialContextFunc that dials through the @@ -916,6 +947,7 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr s.healthChecker.SetManagementConnected(false) } + supportsCrowdSec := s.crowdsecRegistry.Available() mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ ProxyId: s.ID, Version: s.Version, @@ -924,6 +956,7 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr Capabilities: &proto.ProxyCapabilities{ SupportsCustomPorts: &s.SupportsCustomPorts, RequireSubdomain: &s.RequireSubdomain, + SupportsCrowdsec: &supportsCrowdSec, }, }) if err != nil { @@ -1159,7 +1192,7 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin ProxyProtocol: s.l4ProxyProtocol(mapping), DialTimeout: s.l4DialTimeout(mapping), SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), - Filter: parseRestrictions(mapping), + Filter: s.parseRestrictions(mapping), }) s.portMu.Lock() @@ -1234,7 +1267,7 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin ProxyProtocol: s.l4ProxyProtocol(mapping), DialTimeout: s.l4DialTimeout(mapping), SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), - Filter: parseRestrictions(mapping), + Filter: s.parseRestrictions(mapping), }) if tlsPort != s.mainPort { @@ -1268,12 +1301,51 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser // parseRestrictions converts a proto mapping's access restrictions into // a restrict.Filter. Returns nil if the mapping has no restrictions. -func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter { +func (s *Server) parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter { r := mapping.GetAccessRestrictions() if r == nil { return nil } - return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries()) + + svcID := types.ServiceID(mapping.GetId()) + csMode := restrict.CrowdSecMode(r.GetCrowdsecMode()) + + var checker restrict.CrowdSecChecker + if csMode == restrict.CrowdSecEnforce || csMode == restrict.CrowdSecObserve { + if b := s.crowdsecRegistry.Acquire(svcID); b != nil { + checker = b + s.crowdsecMu.Lock() + s.crowdsecServices[svcID] = true + s.crowdsecMu.Unlock() + } else { + s.Logger.Warnf("service %s requests CrowdSec mode %q but proxy has no CrowdSec configured", svcID, csMode) + // Keep the mode: restrict.Filter will fail-closed for enforce (DenyCrowdSecUnavailable) + // and allow for observe. + } + } + + return restrict.ParseFilter(restrict.FilterConfig{ + AllowedCIDRs: r.GetAllowedCidrs(), + BlockedCIDRs: r.GetBlockedCidrs(), + AllowedCountries: r.GetAllowedCountries(), + BlockedCountries: r.GetBlockedCountries(), + CrowdSec: checker, + CrowdSecMode: csMode, + Logger: log.NewEntry(s.Logger), + }) +} + +// releaseCrowdSec releases the CrowdSec bouncer reference for the given +// service if it had one. +func (s *Server) releaseCrowdSec(svcID types.ServiceID) { + s.crowdsecMu.Lock() + had := s.crowdsecServices[svcID] + delete(s.crowdsecServices, svcID) + s.crowdsecMu.Unlock() + + if had { + s.crowdsecRegistry.Release(svcID) + } } // warnIfGeoUnavailable logs a warning if the mapping has country restrictions @@ -1388,7 +1460,7 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t DialTimeout: s.l4DialTimeout(mapping), SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), AccessLog: s.accessLog, - Filter: parseRestrictions(mapping), + Filter: s.parseRestrictions(mapping), Geo: s.geo, }) relay.SetObserver(s.meter) @@ -1425,7 +1497,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader())) } - ipRestrictions := parseRestrictions(mapping) + ipRestrictions := s.parseRestrictions(mapping) s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second @@ -1507,6 +1579,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { // UDP relay cleanup (idempotent). s.removeUDPRelay(svcID) + // Release CrowdSec after all routes are removed so the shared bouncer + // isn't stopped while stale filters can still be reached by in-flight requests. + s.releaseCrowdSec(svcID) } // removeUDPRelay stops and removes a UDP relay by service ID. diff --git a/proxy/web/package-lock.json b/proxy/web/package-lock.json index d16196d77..1611323a7 100644 --- a/proxy/web/package-lock.json +++ b/proxy/web/package-lock.json @@ -15,7 +15,7 @@ "tailwind-merge": "^2.6.0" }, "devDependencies": { - "@eslint/js": "^9.39.1", + "@eslint/js": "9.39.2", "@tailwindcss/vite": "^4.1.18", "@types/node": "^24.10.1", "@types/react": "^19.2.5", @@ -29,7 +29,7 @@ "tsx": "^4.21.0", "typescript": "~5.9.3", "typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "7.3.2" } }, "node_modules/@babel/code-frame": { @@ -1024,9 +1024,9 @@ "license": "MIT" }, "node_modules/@rollup/rollup-android-arm-eabi": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.57.1.tgz", - "integrity": "sha512-A6ehUVSiSaaliTxai040ZpZ2zTevHYbvu/lDoeAteHI8QnaosIzm4qwtezfRg1jOYaUmnzLX1AOD6Z+UJjtifg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.0.tgz", + "integrity": "sha512-WOhNW9K8bR3kf4zLxbfg6Pxu2ybOUbB2AjMDHSQx86LIF4rH4Ft7vmMwNt0loO0eonglSNy4cpD3MKXXKQu0/A==", "cpu": [ "arm" ], @@ -1038,9 +1038,9 @@ ] }, "node_modules/@rollup/rollup-android-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.57.1.tgz", - "integrity": "sha512-dQaAddCY9YgkFHZcFNS/606Exo8vcLHwArFZ7vxXq4rigo2bb494/xKMMwRRQW6ug7Js6yXmBZhSBRuBvCCQ3w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.0.tgz", + "integrity": "sha512-u6JHLll5QKRvjciE78bQXDmqRqNs5M/3GVqZeMwvmjaNODJih/WIrJlFVEihvV0MiYFmd+ZyPr9wxOVbPAG2Iw==", "cpu": [ "arm64" ], @@ -1052,9 +1052,9 @@ ] }, "node_modules/@rollup/rollup-darwin-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.57.1.tgz", - "integrity": "sha512-crNPrwJOrRxagUYeMn/DZwqN88SDmwaJ8Cvi/TN1HnWBU7GwknckyosC2gd0IqYRsHDEnXf328o9/HC6OkPgOg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.0.tgz", + "integrity": "sha512-qEF7CsKKzSRc20Ciu2Zw1wRrBz4g56F7r/vRwY430UPp/nt1x21Q/fpJ9N5l47WWvJlkNCPJz3QRVw008fi7yA==", "cpu": [ "arm64" ], @@ -1066,9 +1066,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.57.1.tgz", - "integrity": "sha512-Ji8g8ChVbKrhFtig5QBV7iMaJrGtpHelkB3lsaKzadFBe58gmjfGXAOfI5FV0lYMH8wiqsxKQ1C9B0YTRXVy4w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.0.tgz", + "integrity": "sha512-WADYozJ4QCnXCH4wPB+3FuGmDPoFseVCUrANmA5LWwGmC6FL14BWC7pcq+FstOZv3baGX65tZ378uT6WG8ynTw==", "cpu": [ "x64" ], @@ -1080,9 +1080,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.57.1.tgz", - "integrity": "sha512-R+/WwhsjmwodAcz65guCGFRkMb4gKWTcIeLy60JJQbXrJ97BOXHxnkPFrP+YwFlaS0m+uWJTstrUA9o+UchFug==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.0.tgz", + "integrity": "sha512-6b8wGHJlDrGeSE3aH5mGNHBjA0TTkxdoNHik5EkvPHCt351XnigA4pS7Wsj/Eo9Y8RBU6f35cjN9SYmCFBtzxw==", "cpu": [ "arm64" ], @@ -1094,9 +1094,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.57.1.tgz", - "integrity": "sha512-IEQTCHeiTOnAUC3IDQdzRAGj3jOAYNr9kBguI7MQAAZK3caezRrg0GxAb6Hchg4lxdZEI5Oq3iov/w/hnFWY9Q==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.0.tgz", + "integrity": "sha512-h25Ga0t4jaylMB8M/JKAyrvvfxGRjnPQIR8lnCayyzEjEOx2EJIlIiMbhpWxDRKGKF8jbNH01NnN663dH638mA==", "cpu": [ "x64" ], @@ -1108,9 +1108,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.57.1.tgz", - "integrity": "sha512-F8sWbhZ7tyuEfsmOxwc2giKDQzN3+kuBLPwwZGyVkLlKGdV1nvnNwYD0fKQ8+XS6hp9nY7B+ZeK01EBUE7aHaw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.0.tgz", + "integrity": "sha512-RzeBwv0B3qtVBWtcuABtSuCzToo2IEAIQrcyB/b2zMvBWVbjo8bZDjACUpnaafaxhTw2W+imQbP2BD1usasK4g==", "cpu": [ "arm" ], @@ -1122,9 +1122,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.57.1.tgz", - "integrity": "sha512-rGfNUfn0GIeXtBP1wL5MnzSj98+PZe/AXaGBCRmT0ts80lU5CATYGxXukeTX39XBKsxzFpEeK+Mrp9faXOlmrw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.0.tgz", + "integrity": "sha512-Sf7zusNI2CIU1HLzuu9Tc5YGAHEZs5Lu7N1ssJG4Tkw6e0MEsN7NdjUDDfGNHy2IU+ENyWT+L2obgWiguWibWQ==", "cpu": [ "arm" ], @@ -1136,9 +1136,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.57.1.tgz", - "integrity": "sha512-MMtej3YHWeg/0klK2Qodf3yrNzz6CGjo2UntLvk2RSPlhzgLvYEB3frRvbEF2wRKh1Z2fDIg9KRPe1fawv7C+g==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.0.tgz", + "integrity": "sha512-DX2x7CMcrJzsE91q7/O02IJQ5/aLkVtYFryqCjduJhUfGKG6yJV8hxaw8pZa93lLEpPTP/ohdN4wFz7yp/ry9A==", "cpu": [ "arm64" ], @@ -1150,9 +1150,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.57.1.tgz", - "integrity": "sha512-1a/qhaaOXhqXGpMFMET9VqwZakkljWHLmZOX48R0I/YLbhdxr1m4gtG1Hq7++VhVUmf+L3sTAf9op4JlhQ5u1Q==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.0.tgz", + "integrity": "sha512-09EL+yFVbJZlhcQfShpswwRZ0Rg+z/CsSELFCnPt3iK+iqwGsI4zht3secj5vLEs957QvFFXnzAT0FFPIxSrkQ==", "cpu": [ "arm64" ], @@ -1164,9 +1164,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.57.1.tgz", - "integrity": "sha512-QWO6RQTZ/cqYtJMtxhkRkidoNGXc7ERPbZN7dVW5SdURuLeVU7lwKMpo18XdcmpWYd0qsP1bwKPf7DNSUinhvA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.0.tgz", + "integrity": "sha512-i9IcCMPr3EXm8EQg5jnja0Zyc1iFxJjZWlb4wr7U2Wx/GrddOuEafxRdMPRYVaXjgbhvqalp6np07hN1w9kAKw==", "cpu": [ "loong64" ], @@ -1178,9 +1178,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.57.1.tgz", - "integrity": "sha512-xpObYIf+8gprgWaPP32xiN5RVTi/s5FCR+XMXSKmhfoJjrpRAjCuuqQXyxUa/eJTdAE6eJ+KDKaoEqjZQxh3Gw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.0.tgz", + "integrity": "sha512-DGzdJK9kyJ+B78MCkWeGnpXJ91tK/iKA6HwHxF4TAlPIY7GXEvMe8hBFRgdrR9Ly4qebR/7gfUs9y2IoaVEyog==", "cpu": [ "loong64" ], @@ -1192,9 +1192,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.57.1.tgz", - "integrity": "sha512-4BrCgrpZo4hvzMDKRqEaW1zeecScDCR+2nZ86ATLhAoJ5FQ+lbHVD3ttKe74/c7tNT9c6F2viwB3ufwp01Oh2w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.0.tgz", + "integrity": "sha512-RwpnLsqC8qbS8z1H1AxBA1H6qknR4YpPR9w2XX0vo2Sz10miu57PkNcnHVaZkbqyw/kUWfKMI73jhmfi9BRMUQ==", "cpu": [ "ppc64" ], @@ -1206,9 +1206,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.57.1.tgz", - "integrity": "sha512-NOlUuzesGauESAyEYFSe3QTUguL+lvrN1HtwEEsU2rOwdUDeTMJdO5dUYl/2hKf9jWydJrO9OL/XSSf65R5+Xw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.0.tgz", + "integrity": "sha512-Z8pPf54Ly3aqtdWC3G4rFigZgNvd+qJlOE52fmko3KST9SoGfAdSRCwyoyG05q1HrrAblLbk1/PSIV+80/pxLg==", "cpu": [ "ppc64" ], @@ -1220,9 +1220,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.57.1.tgz", - "integrity": "sha512-ptA88htVp0AwUUqhVghwDIKlvJMD/fmL/wrQj99PRHFRAG6Z5nbWoWG4o81Nt9FT+IuqUQi+L31ZKAFeJ5Is+A==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.0.tgz", + "integrity": "sha512-3a3qQustp3COCGvnP4SvrMHnPQ9d1vzCakQVRTliaz8cIp/wULGjiGpbcqrkv0WrHTEp8bQD/B3HBjzujVWLOA==", "cpu": [ "riscv64" ], @@ -1234,9 +1234,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.57.1.tgz", - "integrity": "sha512-S51t7aMMTNdmAMPpBg7OOsTdn4tySRQvklmL3RpDRyknk87+Sp3xaumlatU+ppQ+5raY7sSTcC2beGgvhENfuw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.0.tgz", + "integrity": "sha512-pjZDsVH/1VsghMJ2/kAaxt6dL0psT6ZexQVrijczOf+PeP2BUqTHYejk3l6TlPRydggINOeNRhvpLa0AYpCWSQ==", "cpu": [ "riscv64" ], @@ -1248,9 +1248,9 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.57.1.tgz", - "integrity": "sha512-Bl00OFnVFkL82FHbEqy3k5CUCKH6OEJL54KCyx2oqsmZnFTR8IoNqBF+mjQVcRCT5sB6yOvK8A37LNm/kPJiZg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.0.tgz", + "integrity": "sha512-3ObQs0BhvPgiUVZrN7gqCSvmFuMWvWvsjG5ayJ3Lraqv+2KhOsp+pUbigqbeWqueGIsnn+09HBw27rJ+gYK4VQ==", "cpu": [ "s390x" ], @@ -1262,9 +1262,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.57.1.tgz", - "integrity": "sha512-ABca4ceT4N+Tv/GtotnWAeXZUZuM/9AQyCyKYyKnpk4yoA7QIAuBt6Hkgpw8kActYlew2mvckXkvx0FfoInnLg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.0.tgz", + "integrity": "sha512-EtylprDtQPdS5rXvAayrNDYoJhIz1/vzN2fEubo3yLE7tfAw+948dO0g4M0vkTVFhKojnF+n6C8bDNe+gDRdTg==", "cpu": [ "x64" ], @@ -1276,9 +1276,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.57.1.tgz", - "integrity": "sha512-HFps0JeGtuOR2convgRRkHCekD7j+gdAuXM+/i6kGzQtFhlCtQkpwtNzkNj6QhCDp7DRJ7+qC/1Vg2jt5iSOFw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.0.tgz", + "integrity": "sha512-k09oiRCi/bHU9UVFqD17r3eJR9bn03TyKraCrlz5ULFJGdJGi7VOmm9jl44vOJvRJ6P7WuBi/s2A97LxxHGIdw==", "cpu": [ "x64" ], @@ -1290,9 +1290,9 @@ ] }, "node_modules/@rollup/rollup-openbsd-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.57.1.tgz", - "integrity": "sha512-H+hXEv9gdVQuDTgnqD+SQffoWoc0Of59AStSzTEj/feWTBAnSfSD3+Dql1ZruJQxmykT/JVY0dE8Ka7z0DH1hw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.0.tgz", + "integrity": "sha512-1o/0/pIhozoSaDJoDcec+IVLbnRtQmHwPV730+AOD29lHEEo4F5BEUB24H0OBdhbBBDwIOSuf7vgg0Ywxdfiiw==", "cpu": [ "x64" ], @@ -1304,9 +1304,9 @@ ] }, "node_modules/@rollup/rollup-openharmony-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.57.1.tgz", - "integrity": "sha512-4wYoDpNg6o/oPximyc/NG+mYUejZrCU2q+2w6YZqrAs2UcNUChIZXjtafAiiZSUc7On8v5NyNj34Kzj/Ltk6dQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.0.tgz", + "integrity": "sha512-pESDkos/PDzYwtyzB5p/UoNU/8fJo68vcXM9ZW2V0kjYayj1KaaUfi1NmTUTUpMn4UhU4gTuK8gIaFO4UGuMbA==", "cpu": [ "arm64" ], @@ -1318,9 +1318,9 @@ ] }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.57.1.tgz", - "integrity": "sha512-O54mtsV/6LW3P8qdTcamQmuC990HDfR71lo44oZMZlXU4tzLrbvTii87Ni9opq60ds0YzuAlEr/GNwuNluZyMQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.0.tgz", + "integrity": "sha512-hj1wFStD7B1YBeYmvY+lWXZ7ey73YGPcViMShYikqKT1GtstIKQAtfUI6yrzPjAy/O7pO0VLXGmUVWXQMaYgTQ==", "cpu": [ "arm64" ], @@ -1332,9 +1332,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.57.1.tgz", - "integrity": "sha512-P3dLS+IerxCT/7D2q2FYcRdWRl22dNbrbBEtxdWhXrfIMPP9lQhb5h4Du04mdl5Woq05jVCDPCMF7Ub0NAjIew==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.0.tgz", + "integrity": "sha512-SyaIPFoxmUPlNDq5EHkTbiKzmSEmq/gOYFI/3HHJ8iS/v1mbugVa7dXUzcJGQfoytp9DJFLhHH4U3/eTy2Bq4w==", "cpu": [ "ia32" ], @@ -1346,9 +1346,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.57.1.tgz", - "integrity": "sha512-VMBH2eOOaKGtIJYleXsi2B8CPVADrh+TyNxJ4mWPnKfLB/DBUmzW+5m1xUrcwWoMfSLagIRpjUFeW5CO5hyciQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.0.tgz", + "integrity": "sha512-RdcryEfzZr+lAr5kRm2ucN9aVlCCa2QNq4hXelZxb8GG0NJSazq44Z3PCCc8wISRuCVnGs0lQJVX5Vp6fKA+IA==", "cpu": [ "x64" ], @@ -1360,9 +1360,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.57.1.tgz", - "integrity": "sha512-mxRFDdHIWRxg3UfIIAwCm6NzvxG0jDX/wBN6KsQFTvKFqqg9vTrWUE68qEjHt19A5wwx5X5aUi2zuZT7YR0jrA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.0.tgz", + "integrity": "sha512-PrsWNQ8BuE00O3Xsx3ALh2Df8fAj9+cvvX9AIA6o4KpATR98c9mud4XtDWVvsEuyia5U4tVSTKygawyJkjm60w==", "cpu": [ "x64" ], @@ -1926,9 +1926,9 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz", + "integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==", "dev": true, "license": "MIT", "dependencies": { @@ -1936,13 +1936,13 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "version": "9.0.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", + "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", "dev": true, "license": "ISC", "dependencies": { - "brace-expansion": "^2.0.1" + "brace-expansion": "^2.0.2" }, "engines": { "node": ">=16 || 14 >=14.17" @@ -2052,9 +2052,9 @@ } }, "node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "6.14.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", + "integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==", "dev": true, "license": "MIT", "dependencies": { @@ -2109,9 +2109,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "dev": true, "license": "MIT", "dependencies": { @@ -2657,9 +2657,9 @@ } }, "node_modules/flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true, "license": "ISC" }, @@ -3243,9 +3243,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "license": "ISC", "dependencies": { @@ -3386,9 +3386,9 @@ "license": "ISC" }, "node_modules/picomatch": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", - "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", + "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "dev": true, "license": "MIT", "peer": true, @@ -3501,9 +3501,9 @@ } }, "node_modules/rollup": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.57.1.tgz", - "integrity": "sha512-oQL6lgK3e2QZeQ7gcgIkS2YZPg5slw37hYufJ3edKlfQSGGm8ICoxswK15ntSzF/a8+h7ekRy7k7oWc3BQ7y8A==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.0.tgz", + "integrity": "sha512-yqjxruMGBQJ2gG4HtjZtAfXArHomazDHoFwFFmZZl0r7Pdo7qCIXKqKHZc8yeoMgzJJ+pO6pEEHa+V7uzWlrAQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3517,31 +3517,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.57.1", - "@rollup/rollup-android-arm64": "4.57.1", - "@rollup/rollup-darwin-arm64": "4.57.1", - "@rollup/rollup-darwin-x64": "4.57.1", - "@rollup/rollup-freebsd-arm64": "4.57.1", - "@rollup/rollup-freebsd-x64": "4.57.1", - "@rollup/rollup-linux-arm-gnueabihf": "4.57.1", - "@rollup/rollup-linux-arm-musleabihf": "4.57.1", - "@rollup/rollup-linux-arm64-gnu": "4.57.1", - "@rollup/rollup-linux-arm64-musl": "4.57.1", - "@rollup/rollup-linux-loong64-gnu": "4.57.1", - "@rollup/rollup-linux-loong64-musl": "4.57.1", - "@rollup/rollup-linux-ppc64-gnu": "4.57.1", - "@rollup/rollup-linux-ppc64-musl": "4.57.1", - "@rollup/rollup-linux-riscv64-gnu": "4.57.1", - "@rollup/rollup-linux-riscv64-musl": "4.57.1", - "@rollup/rollup-linux-s390x-gnu": "4.57.1", - "@rollup/rollup-linux-x64-gnu": "4.57.1", - "@rollup/rollup-linux-x64-musl": "4.57.1", - "@rollup/rollup-openbsd-x64": "4.57.1", - "@rollup/rollup-openharmony-arm64": "4.57.1", - "@rollup/rollup-win32-arm64-msvc": "4.57.1", - "@rollup/rollup-win32-ia32-msvc": "4.57.1", - "@rollup/rollup-win32-x64-gnu": "4.57.1", - "@rollup/rollup-win32-x64-msvc": "4.57.1", + "@rollup/rollup-android-arm-eabi": "4.60.0", + "@rollup/rollup-android-arm64": "4.60.0", + "@rollup/rollup-darwin-arm64": "4.60.0", + "@rollup/rollup-darwin-x64": "4.60.0", + "@rollup/rollup-freebsd-arm64": "4.60.0", + "@rollup/rollup-freebsd-x64": "4.60.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.60.0", + "@rollup/rollup-linux-arm-musleabihf": "4.60.0", + "@rollup/rollup-linux-arm64-gnu": "4.60.0", + "@rollup/rollup-linux-arm64-musl": "4.60.0", + "@rollup/rollup-linux-loong64-gnu": "4.60.0", + "@rollup/rollup-linux-loong64-musl": "4.60.0", + "@rollup/rollup-linux-ppc64-gnu": "4.60.0", + "@rollup/rollup-linux-ppc64-musl": "4.60.0", + "@rollup/rollup-linux-riscv64-gnu": "4.60.0", + "@rollup/rollup-linux-riscv64-musl": "4.60.0", + "@rollup/rollup-linux-s390x-gnu": "4.60.0", + "@rollup/rollup-linux-x64-gnu": "4.60.0", + "@rollup/rollup-linux-x64-musl": "4.60.0", + "@rollup/rollup-openbsd-x64": "4.60.0", + "@rollup/rollup-openharmony-arm64": "4.60.0", + "@rollup/rollup-win32-arm64-msvc": "4.60.0", + "@rollup/rollup-win32-ia32-msvc": "4.60.0", + "@rollup/rollup-win32-x64-gnu": "4.60.0", + "@rollup/rollup-win32-x64-msvc": "4.60.0", "fsevents": "~2.3.2" } }, @@ -3803,9 +3803,9 @@ } }, "node_modules/vite": { - "version": "7.3.1", - "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz", - "integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==", + "version": "7.3.2", + "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz", + "integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==", "dev": true, "license": "MIT", "peer": true, diff --git a/proxy/web/package.json b/proxy/web/package.json index 97ec1ec0d..9a7c84ed4 100644 --- a/proxy/web/package.json +++ b/proxy/web/package.json @@ -17,7 +17,7 @@ "tailwind-merge": "^2.6.0" }, "devDependencies": { - "@eslint/js": "^9.39.1", + "@eslint/js": "9.39.2", "@tailwindcss/vite": "^4.1.18", "@types/node": "^24.10.1", "@types/react": "^19.2.5", @@ -31,6 +31,6 @@ "tsx": "^4.21.0", "typescript": "~5.9.3", "typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "7.3.2" } } diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 8c3ee1899..067888406 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -1,11 +1,13 @@ package server import ( + "context" "fmt" - "net" + "time" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/shared/relay/messages" //nolint:staticcheck "github.com/netbirdio/netbird/shared/relay/messages/address" @@ -13,6 +15,12 @@ import ( authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth" ) +const ( + // handshakeTimeout bounds how long a connection may remain in the + // pre-authentication handshake phase before being closed. + handshakeTimeout = 10 * time.Second +) + type Validator interface { Validate(any) error // Deprecated: Use Validate instead. @@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { } type handshake struct { - conn net.Conn + conn listener.Conn validator Validator preparedMsg *preparedMsg @@ -66,9 +74,9 @@ type handshake struct { peerID *messages.PeerID } -func (h *handshake) handshakeReceive() (*messages.PeerID, error) { +func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) { buf := make([]byte, messages.MaxHandshakeSize) - n, err := h.conn.Read(buf) + n, err := h.conn.Read(ctx, buf) if err != nil { return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) } @@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) { return peerID, nil } -func (h *handshake) handshakeResponse() error { +func (h *handshake) handshakeResponse(ctx context.Context) error { var responseMsg []byte if h.handshakeMethodAuth { responseMsg = h.preparedMsg.responseAuthMsg @@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error { responseMsg = h.preparedMsg.responseHelloMsg } - if _, err := h.conn.Write(responseMsg); err != nil { + if _, err := h.conn.Write(ctx, responseMsg); err != nil { return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) } diff --git a/relay/server/listener/conn.go b/relay/server/listener/conn.go new file mode 100644 index 000000000..ef0869594 --- /dev/null +++ b/relay/server/listener/conn.go @@ -0,0 +1,14 @@ +package listener + +import ( + "context" + "net" +) + +// Conn is the relay connection contract implemented by WS and QUIC transports. +type Conn interface { + Read(ctx context.Context, b []byte) (n int, err error) + Write(ctx context.Context, b []byte) (n int, err error) + RemoteAddr() net.Addr + Close() error +} diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go deleted file mode 100644 index 0a79182f4..000000000 --- a/relay/server/listener/listener.go +++ /dev/null @@ -1,14 +0,0 @@ -package listener - -import ( - "context" - "net" - - "github.com/netbirdio/netbird/relay/protocol" -) - -type Listener interface { - Listen(func(conn net.Conn)) error - Shutdown(ctx context.Context) error - Protocol() protocol.Protocol -} diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go index 6e2201bf7..d8dafcd1f 100644 --- a/relay/server/listener/quic/conn.go +++ b/relay/server/listener/quic/conn.go @@ -3,33 +3,26 @@ package quic import ( "context" "errors" - "fmt" "net" "sync" - "time" "github.com/quic-go/quic-go" ) type Conn struct { - session *quic.Conn - closed bool - closedMu sync.Mutex - ctx context.Context - ctxCancel context.CancelFunc + session *quic.Conn + closed bool + closedMu sync.Mutex } func NewConn(session *quic.Conn) *Conn { - ctx, cancel := context.WithCancel(context.Background()) return &Conn{ - session: session, - ctx: ctx, - ctxCancel: cancel, + session: session, } } -func (c *Conn) Read(b []byte) (n int, err error) { - dgram, err := c.session.ReceiveDatagram(c.ctx) +func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) { + dgram, err := c.session.ReceiveDatagram(ctx) if err != nil { return 0, c.remoteCloseErrHandling(err) } @@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) { return n, nil } -func (c *Conn) Write(b []byte) (int, error) { +func (c *Conn) Write(_ context.Context, b []byte) (int, error) { if err := c.session.SendDatagram(b); err != nil { return 0, c.remoteCloseErrHandling(err) } return len(b), nil } -func (c *Conn) LocalAddr() net.Addr { - return c.session.LocalAddr() -} - func (c *Conn) RemoteAddr() net.Addr { return c.session.RemoteAddr() } -func (c *Conn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("SetWriteDeadline is not implemented") -} - -func (c *Conn) SetDeadline(t time.Time) error { - return fmt.Errorf("SetDeadline is not implemented") -} - func (c *Conn) Close() error { c.closedMu.Lock() if c.closed { @@ -74,8 +51,6 @@ func (c *Conn) Close() error { c.closed = true c.closedMu.Unlock() - c.ctxCancel() // Cancel the context - sessionErr := c.session.CloseWithError(0, "normal closure") return sessionErr } diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go index 797223e74..68f0e03c0 100644 --- a/relay/server/listener/quic/listener.go +++ b/relay/server/listener/quic/listener.go @@ -5,12 +5,12 @@ import ( "crypto/tls" "errors" "fmt" - "net" "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + relaylistener "github.com/netbirdio/netbird/relay/server/listener" nbRelay "github.com/netbirdio/netbird/shared/relay" ) @@ -25,7 +25,7 @@ type Listener struct { listener *quic.Listener } -func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { +func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error { quicCfg := &quic.Config{ EnableDatagrams: true, InitialPacketSize: nbRelay.QUICInitialPacketSize, diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index d5bce56f7..c22b5719d 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -18,25 +18,21 @@ const ( type Conn struct { *websocket.Conn - lAddr *net.TCPAddr rAddr *net.TCPAddr closed bool closedMu sync.Mutex - ctx context.Context } -func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { +func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn { return &Conn{ Conn: wsConn, - lAddr: lAddr, rAddr: rAddr, - ctx: context.Background(), } } -func (c *Conn) Read(b []byte) (n int, err error) { - t, r, err := c.Reader(c.ctx) +func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) { + t, r, err := c.Reader(ctx) if err != nil { return 0, c.ioErrHandling(err) } @@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Write writes a binary message with the given payload. // It does not block until fill the internal buffer. // If the buffer filled up, wait until the buffer is drained or timeout. -func (c *Conn) Write(b []byte) (int, error) { - ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout) +func (c *Conn) Write(ctx context.Context, b []byte) (int, error) { + ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout) defer ctxCancel() err := c.Conn.Write(ctx, websocket.MessageBinary, b) return len(b), err } -func (c *Conn) LocalAddr() net.Addr { - return c.lAddr -} - func (c *Conn) RemoteAddr() net.Addr { return c.rAddr } -func (c *Conn) SetReadDeadline(t time.Time) error { - return fmt.Errorf("SetReadDeadline is not implemented") -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("SetWriteDeadline is not implemented") -} - -func (c *Conn) SetDeadline(t time.Time) error { - return fmt.Errorf("SetDeadline is not implemented") -} - func (c *Conn) Close() error { c.closedMu.Lock() c.closed = true diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 12219e29b..ba175f901 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -7,11 +7,13 @@ import ( "fmt" "net" "net/http" + "time" "github.com/coder/websocket" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + relaylistener "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/shared/relay" ) @@ -27,18 +29,19 @@ type Listener struct { TLSConfig *tls.Config server *http.Server - acceptFn func(conn net.Conn) + acceptFn func(conn relaylistener.Conn) } -func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { +func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error { l.acceptFn = acceptFn mux := http.NewServeMux() mux.HandleFunc(URLPath, l.onAccept) l.server = &http.Server{ - Addr: l.Address, - Handler: mux, - TLSConfig: l.TLSConfig, + Addr: l.Address, + Handler: mux, + TLSConfig: l.TLSConfig, + ReadHeaderTimeout: 5 * time.Second, } log.Infof("WS server listening address: %s", l.Address) @@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { return } - lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr) - if err != nil { - err = wsConn.Close(websocket.StatusInternalError, "internal error") - if err != nil { - log.Errorf("failed to close ws connection: %s", err) - } - return - } - log.Infof("WS client connected from: %s", rAddr) - conn := NewConn(wsConn, lAddr, rAddr) + conn := NewConn(wsConn, rAddr) l.acceptFn(conn) } diff --git a/relay/server/peer.go b/relay/server/peer.go index c5ff41857..8376cdfa7 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/metrics" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/messages" @@ -26,11 +27,14 @@ type Peer struct { metrics *metrics.Metrics log *log.Entry id messages.PeerID - conn net.Conn + conn listener.Conn connMu sync.RWMutex store *store.Store notifier *store.PeerNotifier + ctx context.Context + ctxCancel context.CancelFunc + peersListener *store.Listener // between the online peer collection step and the notification sending should not be sent offline notifications from another thread @@ -38,14 +42,17 @@ type Peer struct { } // NewPeer creates a new Peer instance and prepare custom logging -func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { +func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { + ctx, cancel := context.WithCancel(context.Background()) p := &Peer{ - metrics: metrics, - log: log.WithField("peer_id", id.String()), - id: id, - conn: conn, - store: store, - notifier: notifier, + metrics: metrics, + log: log.WithField("peer_id", id.String()), + id: id, + conn: conn, + store: store, + notifier: notifier, + ctx: ctx, + ctxCancel: cancel, } return p @@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store func (p *Peer) Work() { p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline) defer func() { + p.ctxCancel() p.notifier.RemoveListener(p.peersListener) if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { @@ -64,8 +72,7 @@ func (p *Peer) Work() { } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := p.ctx hc := healthcheck.NewSender(p.log) go hc.StartHealthCheck(ctx) @@ -73,7 +80,7 @@ func (p *Peer) Work() { buf := make([]byte, bufferSize) for { - n, err := p.conn.Read(buf) + n, err := p.conn.Read(ctx, buf) if err != nil { if !errors.Is(err, net.ErrClosed) { p.log.Errorf("failed to read message: %s", err) @@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * } // Write writes data to the connection -func (p *Peer) Write(b []byte) (int, error) { +func (p *Peer) Write(ctx context.Context, b []byte) (int, error) { p.connMu.RLock() defer p.connMu.RUnlock() - return p.conn.Write(b) + return p.conn.Write(ctx, b) } // CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the @@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) { p.log.Errorf("failed to send close message to peer: %s", p.String()) } + p.ctxCancel() if err := p.conn.Close(); err != nil { p.log.Errorf(errCloseConn, err) } @@ -156,6 +164,7 @@ func (p *Peer) Close() { p.connMu.Lock() defer p.connMu.Unlock() + p.ctxCancel() if err := p.conn.Close(); err != nil { p.log.Errorf(errCloseConn, err) } @@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - writeDone := make(chan struct{}) - var err error - go func() { - _, err = p.conn.Write(buf) - close(writeDone) - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-writeDone: - return err - } + _, err := p.conn.Write(ctx, buf) + return err } func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) { for { select { case <-hc.HealthCheck: - _, err := p.Write(messages.MarshalHealthcheck()) + _, err := p.Write(ctx, messages.MarshalHealthcheck()) if err != nil { p.log.Errorf("failed to send healthcheck message: %s", err) return @@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - n, err := dp.Write(msg) + n, err := dp.Write(dp.ctx, msg) if err != nil { p.log.Errorf("failed to write transport message to: %s", dp.String()) return } - p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) + p.metrics.TransferBytesSent.Add(p.ctx, int64(n)) } func (p *Peer) handleSubscribePeerState(msg []byte) { @@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) { } for n, msg := range msgs { - if _, err := p.Write(msg); err != nil { + if _, err := p.Write(p.ctx, msg); err != nil { p.log.Errorf("failed to write %d. peers offline message: %s", n, err) } } @@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) { } for n, msg := range msgs { - if _, err := p.Write(msg); err != nil { + if _, err := p.Write(p.ctx, msg); err != nil { p.log.Errorf("failed to write %d. peers offline message: %s", n, err) } } diff --git a/relay/server/relay.go b/relay/server/relay.go index bb355f58f..56add8bea 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "net" "net/url" "sync" "time" @@ -13,11 +12,20 @@ import ( "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/healthcheck/peerid" + "github.com/netbirdio/netbird/relay/protocol" + "github.com/netbirdio/netbird/relay/server/listener" + //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" ) +type Listener interface { + Listen(func(conn listener.Conn)) error + Shutdown(ctx context.Context) error + Protocol() protocol.Protocol +} + type Config struct { Meter metric.Meter ExposedAddress string @@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) { } // Accept start to handle a new peer connection -func (r *Relay) Accept(conn net.Conn) { +func (r *Relay) Accept(conn listener.Conn) { acceptTime := time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() @@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) { return } + hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout) + defer hsCancel() + h := handshake{ conn: conn, validator: r.validator, preparedMsg: r.preparedMsg, } - peerID, err := h.handshakeReceive() + peerID, err := h.handshakeReceive(hsCtx) if err != nil { if peerid.IsHealthCheck(peerID) { log.Debugf("health check connection from %s", conn.RemoteAddr()) @@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) { r.metrics.PeerDisconnected(peer.String()) }() - if err := h.handshakeResponse(); err != nil { + if err := h.handshakeResponse(hsCtx); err != nil { log.Errorf("failed to send handshake response, close peer: %s", err) peer.Close() } diff --git a/relay/server/server.go b/relay/server/server.go index a0f7eb73c..340da55b8 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" "crypto/tls" - "net" "net/url" "sync" @@ -31,7 +30,7 @@ type ListenerConfig struct { // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { relay *Relay - listeners []listener.Listener + listeners []Listener listenerMux sync.Mutex } @@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) { } return &Server{ relay: relay, - listeners: make([]listener.Listener, 0, 2), + listeners: make([]Listener, 0, 2), }, nil } @@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error { wg := sync.WaitGroup{} for _, l := range r.listeners { wg.Add(1) - go func(listener listener.Listener) { + go func(listener Listener) { defer wg.Done() errChan <- listener.Listen(r.relay.Accept) }(l) @@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL { // RelayAccept returns the relay's Accept function for handling incoming connections. // This allows external HTTP handlers to route connections to the relay without // starting the relay's own listeners. -func (r *Server) RelayAccept() func(conn net.Conn) { +func (r *Server) RelayAccept() func(conn listener.Conn) { return r.relay.Accept } diff --git a/shared/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go index 5806d1f4d..d113f53d5 100644 --- a/shared/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -146,7 +146,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string userJWTGroups := make([]string, 0) if claim, ok := claims[claimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { + switch claimGroups := claim.(type) { + case string: + // Some IdPs emit a single group claim as a string instead of an array. + userJWTGroups = append(userJWTGroups, claimGroups) + case []any: for _, g := range claimGroups { if group, ok := g.(string); ok { userJWTGroups = append(userJWTGroups, group) @@ -154,9 +158,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) } } + default: + log.Debugf("JWT claim %q is not a string or string array (type: %T): %v", claimName, claim, claim) } } else { - log.Debugf("JWT claim %q is not a string array", claimName) + log.Debugf("JWT claim %q is missing", claimName) } return userJWTGroups diff --git a/shared/auth/jwt/extractor_test.go b/shared/auth/jwt/extractor_test.go index 45529770d..4f8fe0007 100644 --- a/shared/auth/jwt/extractor_test.go +++ b/shared/auth/jwt/extractor_test.go @@ -249,6 +249,15 @@ func TestClaimsExtractor_ToGroups(t *testing.T) { groupClaimName: "groups", expectedGroups: []string{}, }, + { + name: "extracts single group string from claim", + claims: jwt.MapClaims{ + "sub": "user-123", + "groups": "admin", + }, + groupClaimName: "groups", + expectedGroups: []string{"admin"}, + }, { name: "handles custom claim name", claims: jwt.MapClaims{ diff --git a/shared/auth/jwt/validator.go b/shared/auth/jwt/validator.go index aeaa5842c..cf18b2cf6 100644 --- a/shared/auth/jwt/validator.go +++ b/shared/auth/jwt/validator.go @@ -25,7 +25,7 @@ import ( // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { Keys []JSONWebKey `json:"keys"` - expiresInTime time.Time + ExpiresInTime time.Time `json:"-"` } // The supported elliptic curves types @@ -53,12 +53,17 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } +// KeyFetcher is a function that retrieves JWKS keys directly (e.g., from Dex storage) +// bypassing HTTP. When set on a Validator, it is used instead of the HTTP-based getPemKeys. +type KeyFetcher func(ctx context.Context) (*Jwks, error) + type Validator struct { lock sync.Mutex issuer string audienceList []string keysLocation string idpSignkeyRefreshEnabled bool + keyFetcher KeyFetcher keys *Jwks lastForcedRefresh time.Time } @@ -85,10 +90,39 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp } } +// NewValidatorWithKeyFetcher creates a Validator that fetches keys directly using the +// provided KeyFetcher (e.g., from Dex storage) instead of via HTTP. +func NewValidatorWithKeyFetcher(issuer string, audienceList []string, keyFetcher KeyFetcher) *Validator { + ctx := context.Background() + keys, err := keyFetcher(ctx) + if err != nil { + log.Warnf("could not get keys from key fetcher: %s, it will try again on the next http request", err) + } + if keys == nil { + keys = &Jwks{} + } + + return &Validator{ + keys: keys, + issuer: issuer, + audienceList: audienceList, + idpSignkeyRefreshEnabled: true, + keyFetcher: keyFetcher, + } +} + // forcedRefreshCooldown is the minimum time between forced key refreshes // to prevent abuse from invalid tokens with fake kid values const forcedRefreshCooldown = 30 * time.Second +// fetchKeys retrieves keys using the keyFetcher if available, otherwise falls back to HTTP. +func (v *Validator) fetchKeys(ctx context.Context) (*Jwks, error) { + if v.keyFetcher != nil { + return v.keyFetcher(ctx) + } + return getPemKeys(v.keysLocation) +} + func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { // If keys are rotated, verify the keys prior to token validation @@ -131,13 +165,13 @@ func (v *Validator) refreshKeys(ctx context.Context) { v.lock.Lock() defer v.lock.Unlock() - refreshedKeys, err := getPemKeys(v.keysLocation) + refreshedKeys, err := v.fetchKeys(ctx) if err != nil { log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) return } - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC()) v.keys = refreshedKeys } @@ -155,13 +189,13 @@ func (v *Validator) forceRefreshKeys(ctx context.Context) bool { log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh") - refreshedKeys, err := getPemKeys(v.keysLocation) + refreshedKeys, err := v.fetchKeys(ctx) if err != nil { log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) return false } - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC()) v.keys = refreshedKeys v.lastForcedRefresh = time.Now() return true @@ -203,7 +237,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To // stillValid returns true if the JSONWebKey still valid and have enough time to be used func (jwks *Jwks) stillValid() bool { - return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) + return !jwks.ExpiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.ExpiresInTime) } func getPemKeys(keysLocation string) (*Jwks, error) { @@ -227,7 +261,7 @@ func getPemKeys(keysLocation string) (*Jwks, error) { cacheControlHeader := resp.Header.Get("Cache-Control") expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) - jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) + jwks.ExpiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) return jwks, nil } diff --git a/shared/management/client/client.go b/shared/management/client/client.go index a15301223..18efba87b 100644 --- a/shared/management/client/client.go +++ b/shared/management/client/client.go @@ -4,8 +4,6 @@ import ( "context" "io" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" @@ -16,14 +14,18 @@ type Client interface { io.Closer Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error - GetServerPublicKey() (*wgtypes.Key, error) - Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) - GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) + Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) + GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) GetServerURL() string + // IsHealthy returns the current connection status without blocking. + // Used by the engine to monitor connectivity in the background. IsHealthy() bool + // HealthCheck actively probes the management server and returns an error if unreachable. + // Used to validate connectivity before committing configuration changes. + HealthCheck() error SyncMeta(sysInfo *system.Info) error Logout() error CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index a11f863a7..a8e8172dc 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" @@ -95,9 +96,16 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { settingsManagerMock := settings.NewMockManager(ctrl) jobManager := job.NewJobManager(nil, store, peersManger) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -116,11 +124,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) } @@ -131,7 +138,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) } @@ -189,7 +196,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) { } } -func TestClient_GetServerPublicKey(t *testing.T) { +func TestClient_HealthCheck(t *testing.T) { testKey, err := wgtypes.GenerateKey() if err != nil { t.Fatal(err) @@ -203,12 +210,8 @@ func TestClient_GetServerPublicKey(t *testing.T) { t.Fatal(err) } - key, err := client.GetServerPublicKey() - if err != nil { - t.Error("couldn't retrieve management public key") - } - if key == nil { - t.Error("got an empty management public key") + if err := client.HealthCheck(); err != nil { + t.Errorf("health check failed: %v", err) } } @@ -225,12 +228,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) { if err != nil { t.Fatal(err) } - key, err := client.GetServerPublicKey() - if err != nil { - t.Fatal(err) - } sysInfo := system.GetInfo(context.TODO()) - _, err = client.Login(*key, sysInfo, nil, nil) + _, err = client.Login(sysInfo, nil, nil) if err == nil { t.Error("expecting err on unregistered login, got nil") } @@ -253,12 +252,8 @@ func TestClient_LoginRegistered(t *testing.T) { t.Fatal(err) } - key, err := client.GetServerPublicKey() - if err != nil { - t.Error(err) - } info := system.GetInfo(context.TODO()) - resp, err := client.Register(*key, ValidKey, "", info, nil, nil) + resp, err := client.Register(ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -282,13 +277,8 @@ func TestClient_Sync(t *testing.T) { t.Fatal(err) } - serverKey, err := client.GetServerPublicKey() - if err != nil { - t.Error(err) - } - info := system.GetInfo(context.TODO()) - _, err = client.Register(*serverKey, ValidKey, "", info, nil, nil) + _, err = client.Register(ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -304,7 +294,7 @@ func TestClient_Sync(t *testing.T) { } info = system.GetInfo(context.TODO()) - _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil) + _, err = remoteClient.Register(ValidKey, "", info, nil, nil) if err != nil { t.Fatal(err) } @@ -364,11 +354,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) { t.Fatalf("error while creating testClient: %v", err) } - key, err := testClient.GetServerPublicKey() - if err != nil { - t.Fatalf("error while getting server public key from testclient, %v", err) - } - var actualMeta *mgmtProto.PeerSystemMeta var actualValidKey string var wg sync.WaitGroup @@ -405,7 +390,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) { } info := system.GetInfo(context.TODO()) - _, err = testClient.Register(*key, ValidKey, "", info, nil, nil) + _, err = testClient.Register(ValidKey, "", info, nil, nil) if err != nil { t.Errorf("error while trying to register client: %v", err) } @@ -505,7 +490,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { } mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { - encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) + encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo) if err != nil { return nil, err } @@ -517,7 +502,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { }, nil } - flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey) + flowInfo, err := client.GetDeviceAuthorizationFlow() if err != nil { t.Error("error while retrieving device auth flow information") } @@ -551,7 +536,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { } mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { - encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) + encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo) if err != nil { return nil, err } @@ -563,11 +548,11 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { }, nil } - flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey) + flowInfo, err := client.GetPKCEAuthorizationFlow() if err != nil { t.Error("error while retrieving pkce auth flow information") } assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") - assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") + assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") //nolint:staticcheck } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 252199498..80625fe06 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -30,6 +30,8 @@ import ( const ConnectTimeout = 10 * time.Second +const healthCheckTimeout = 5 * time.Second + const ( // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) // for the management client connection. Value is in bytes. @@ -202,7 +204,7 @@ func (c *GrpcClient) withMgmtStream( return fmt.Errorf("connection to management is not ready and in %s state", connState) } - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { log.Debugf(errMsgMgmtPublicKey, err) return err @@ -244,29 +246,23 @@ func (c *GrpcClient) handleJobStream( for { jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey) if err != nil { + if ctx.Err() != nil { + log.Debugf("job stream context has been canceled, this usually indicates shutdown") + return nil + } if s, ok := gstatus.FromError(err); ok { switch s.Code() { case codes.PermissionDenied: c.notifyDisconnected(err) return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer - case codes.Canceled: - log.Debugf("management connection context has been canceled, this usually indicates shutdown") - return err case codes.Unimplemented: log.Warn("Job feature is not supported by the current management server version. " + "Please update the management service to use this feature.") return nil - default: - c.notifyDisconnected(err) - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err } - } else { - // non-gRPC error - c.notifyDisconnected(err) - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err } + log.Warnf("job stream disconnected, will retry silently. Reason: %v", err) + return err } if jobReq == nil || len(jobReq.ID) == 0 { @@ -381,22 +377,15 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes. err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler) if err != nil { c.notifyDisconnected(err) - if s, ok := gstatus.FromError(err); ok { - switch s.Code() { - case codes.PermissionDenied: - return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer - case codes.Canceled: - log.Debugf("management connection context has been canceled, this usually indicates shutdown") - return nil - default: - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err - } - } else { - // non-gRPC error - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err + if ctx.Err() != nil { + log.Debugf("management connection context has been canceled, this usually indicates shutdown") + return nil } + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { + return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer + } + log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) + return err } return nil @@ -404,7 +393,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes. // GetNetworkMap return with the network map func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) { - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { log.Debugf("failed getting Management Service public key: %s", err) return nil, err @@ -490,18 +479,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli } } -// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server) -func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { +// HealthCheck actively probes the management server and returns an error if unreachable. +// Used to validate connectivity before committing configuration changes. +func (c *GrpcClient) HealthCheck() error { if !c.ready() { - return nil, errors.New(errMsgNoMgmtConnection) + return errors.New(errMsgNoMgmtConnection) } + _, err := c.getServerPublicKey() + return err +} + +// getServerPublicKey fetches the server's WireGuard public key. +func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) { mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) defer cancel() resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, fmt.Errorf("failed while getting Management Service public key") + return nil, fmt.Errorf("failed getting Management Service public key: %w", err) } serverKey, err := wgtypes.ParseKey(resp.Key) @@ -512,7 +507,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { return &serverKey, nil } -// IsHealthy probes the gRPC connection and returns false on errors +// IsHealthy returns the current connection status without blocking. +// Used by the engine to monitor connectivity in the background. func (c *GrpcClient) IsHealthy() bool { switch c.conn.GetState() { case connectivity.TransientFailure: @@ -525,7 +521,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() _, err := c.realClient.GetServerKey(ctx, &proto.Empty{}) @@ -538,12 +534,17 @@ func (c *GrpcClient) IsHealthy() bool { return true } -func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { +func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) { if !c.ready() { return nil, errors.New(errMsgNoMgmtConnection) } - loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) + serverKey, err := c.getServerPublicKey() + if err != nil { + return nil, err + } + + loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) return nil, err @@ -577,7 +578,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro } loginResp := &proto.LoginResponse{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp) if err != nil { log.Errorf("failed to decrypt login response: %s", err) return nil, err @@ -589,34 +590,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Takes care of encrypting and decrypting messages. // This method will also collect system info and send it with the request (e.g. hostname, os, etc) -func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) } // Login attempts login to Management Server. Takes care of encrypting and decrypting messages. -func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) } // GetDeviceAuthorizationFlow returns a device authorization flow information. // It also takes care of encrypting and decrypting messages. -func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { +func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management in order to get device authorization flow") } + + serverKey, err := c.getServerPublicKey() + if err != nil { + return nil, err + } + mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) defer cancel() message := &proto.DeviceAuthorizationFlowRequest{} - encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) + encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message) if err != nil { return nil, err } @@ -630,7 +637,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D } flowInfoResp := &proto.DeviceAuthorizationFlow{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp) if err != nil { errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err) log.Error(errWithMSG) @@ -642,15 +649,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D // GetPKCEAuthorizationFlow returns a pkce authorization flow information. // It also takes care of encrypting and decrypting messages. -func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { +func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow") } + + serverKey, err := c.getServerPublicKey() + if err != nil { + return nil, err + } + mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) defer cancel() message := &proto.PKCEAuthorizationFlowRequest{} - encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) + encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message) if err != nil { return nil, err } @@ -664,7 +677,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC } flowInfoResp := &proto.PKCEAuthorizationFlow{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp) if err != nil { errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err) log.Error(errWithMSG) @@ -681,7 +694,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error { return errors.New(errMsgNoMgmtConnection) } - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { log.Debugf(errMsgMgmtPublicKey, err) return err @@ -724,7 +737,7 @@ func (c *GrpcClient) notifyConnected() { } func (c *GrpcClient) Logout() error { - serverKey, err := c.GetServerPublicKey() + serverKey, err := c.getServerPublicKey() if err != nil { return fmt.Errorf("get server public key: %w", err) } @@ -751,7 +764,7 @@ func (c *GrpcClient) Logout() error { // CreateExpose calls the management server to create a new expose service. func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) { - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { return nil, err } @@ -787,7 +800,7 @@ func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*Expo // RenewExpose extends the TTL of an active expose session on the management server. func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error { - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { return err } @@ -810,7 +823,7 @@ func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error { // StopExpose terminates an active expose session on the management server. func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error { - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { return err } diff --git a/shared/management/client/mock.go b/shared/management/client/mock.go index 548e379e8..361e8ffad 100644 --- a/shared/management/client/mock.go +++ b/shared/management/client/mock.go @@ -3,8 +3,6 @@ package client import ( "context" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" @@ -14,12 +12,12 @@ import ( type MockClient struct { CloseFunc func() error SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error - GetServerPublicKeyFunc func() (*wgtypes.Key, error) - RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) - GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) + RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error) + GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error) GetServerURLFunc func() string + HealthCheckFunc func() error SyncMetaFunc func(sysInfo *system.Info) error LogoutFunc func() error JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error @@ -53,39 +51,39 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ return m.JobFunc(ctx, msgHandler) } -func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { - if m.GetServerPublicKeyFunc == nil { - return nil, nil - } - return m.GetServerPublicKeyFunc() -} - -func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.RegisterFunc == nil { return nil, nil } - return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels) + return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels) } -func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.LoginFunc == nil { return nil, nil } - return m.LoginFunc(serverKey, info, sshKey, dnsLabels) + return m.LoginFunc(info, sshKey, dnsLabels) } -func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { +func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) { if m.GetDeviceAuthorizationFlowFunc == nil { return nil, nil } - return m.GetDeviceAuthorizationFlowFunc(serverKey) + return m.GetDeviceAuthorizationFlowFunc() } -func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { +func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) { if m.GetPKCEAuthorizationFlowFunc == nil { return nil, nil } - return m.GetPKCEAuthorizationFlowFunc(serverKey) + return m.GetPKCEAuthorizationFlowFunc() +} + +func (m *MockClient) HealthCheck() error { + if m.HealthCheckFunc == nil { + return nil + } + return m.HealthCheckFunc() } // GetNetworkMap mock implementation of GetNetworkMap from Client interface. diff --git a/shared/management/client/rest/azure_idp.go b/shared/management/client/rest/azure_idp.go new file mode 100644 index 000000000..40b90bc30 --- /dev/null +++ b/shared/management/client/rest/azure_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// AzureIDPAPI APIs for Azure AD IDP integrations +type AzureIDPAPI struct { + c *Client +} + +// List retrieves all Azure AD IDP integrations +func (a *AzureIDPAPI) List(ctx context.Context) ([]api.AzureIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.AzureIntegration](resp) + return ret, err +} + +// Get retrieves a specific Azure AD IDP integration by ID +func (a *AzureIDPAPI) Get(ctx context.Context, integrationID string) (*api.AzureIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Create creates a new Azure AD IDP integration +func (a *AzureIDPAPI) Create(ctx context.Context, request api.CreateAzureIntegrationRequest) (*api.AzureIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Update updates an existing Azure AD IDP integration +func (a *AzureIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateAzureIntegrationRequest) (*api.AzureIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/azure-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Delete deletes an Azure AD IDP integration +func (a *AzureIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/azure-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// Sync triggers a manual sync for an Azure AD IDP integration +func (a *AzureIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp/"+integrationID+"/sync", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.SyncResult](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for an Azure AD IDP integration +func (a *AzureIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/azure_idp_test.go b/shared/management/client/rest/azure_idp_test.go new file mode 100644 index 000000000..480d2a313 --- /dev/null +++ b/shared/management/client/rest/azure_idp_test.go @@ -0,0 +1,252 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testAzureIntegration = api.AzureIntegration{ + Id: 1, + Enabled: true, + ClientId: "12345678-1234-1234-1234-123456789012", + TenantId: "87654321-4321-4321-4321-210987654321", + SyncInterval: 300, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + Host: "microsoft.com", + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestAzureIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.AzureIntegration{testAzureIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testAzureIntegration, ret[0]) + }) +} + +func TestAzureIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestAzureIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateAzureIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "12345678-1234-1234-1234-123456789012", req.ClientId) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{ + ClientId: "12345678-1234-1234-1234-123456789012", + ClientSecret: "secret", + TenantId: "87654321-4321-4321-4321-210987654321", + Host: api.CreateAzureIntegrationRequestHostMicrosoftCom, + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateAzureIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.AzureIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestAzureIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.AzureIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestAzureIDP_Sync_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Sync(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, "ok", *ret.Result) + }) +} + +func TestAzureIDP_Sync_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Sync(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestAzureIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/client/rest/client.go b/shared/management/client/rest/client.go index f308761fb..f0cb4d2d1 100644 --- a/shared/management/client/rest/client.go +++ b/shared/management/client/rest/client.go @@ -110,6 +110,15 @@ type Client struct { // see more: https://docs.netbird.io/api/resources/scim SCIM *SCIMAPI + // GoogleIDP NetBird Google Workspace IDP integration APIs + GoogleIDP *GoogleIDPAPI + + // AzureIDP NetBird Azure AD IDP integration APIs + AzureIDP *AzureIDPAPI + + // OktaScimIDP NetBird Okta SCIM IDP integration APIs + OktaScimIDP *OktaScimIDPAPI + // EventStreaming NetBird Event Streaming integration APIs // see more: https://docs.netbird.io/api/resources/event-streaming EventStreaming *EventStreamingAPI @@ -185,6 +194,9 @@ func (c *Client) initialize() { c.MSP = &MSPAPI{c} c.EDR = &EDRAPI{c} c.SCIM = &SCIMAPI{c} + c.GoogleIDP = &GoogleIDPAPI{c} + c.AzureIDP = &AzureIDPAPI{c} + c.OktaScimIDP = &OktaScimIDPAPI{c} c.EventStreaming = &EventStreamingAPI{c} c.IdentityProviders = &IdentityProvidersAPI{c} c.Ingress = &IngressAPI{c} diff --git a/shared/management/client/rest/edr.go b/shared/management/client/rest/edr.go index 7dfc891c2..f9b7f2a88 100644 --- a/shared/management/client/rest/edr.go +++ b/shared/management/client/rest/edr.go @@ -265,6 +265,65 @@ func (a *EDRAPI) DeleteHuntressIntegration(ctx context.Context) error { return nil } +// GetFleetDMIntegration retrieves the EDR FleetDM integration. +func (a *EDRAPI) GetFleetDMIntegration(ctx context.Context) (*api.EDRFleetDMResponse, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/fleetdm", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.EDRFleetDMResponse](resp) + return &ret, err +} + +// CreateFleetDMIntegration creates a new EDR FleetDM integration. +func (a *EDRAPI) CreateFleetDMIntegration(ctx context.Context, request api.EDRFleetDMRequest) (*api.EDRFleetDMResponse, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/fleetdm", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.EDRFleetDMResponse](resp) + return &ret, err +} + +// UpdateFleetDMIntegration updates an existing EDR FleetDM integration. +func (a *EDRAPI) UpdateFleetDMIntegration(ctx context.Context, request api.EDRFleetDMRequest) (*api.EDRFleetDMResponse, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/fleetdm", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.EDRFleetDMResponse](resp) + return &ret, err +} + +// DeleteFleetDMIntegration deletes the EDR FleetDM integration. +func (a *EDRAPI) DeleteFleetDMIntegration(ctx context.Context) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/fleetdm", nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + // BypassPeerCompliance bypasses compliance for a non-compliant peer // See more: https://docs.netbird.io/api/resources/edr#bypass-peer-compliance func (a *EDRAPI) BypassPeerCompliance(ctx context.Context, peerID string) (*api.BypassResponse, error) { diff --git a/shared/management/client/rest/google_idp.go b/shared/management/client/rest/google_idp.go new file mode 100644 index 000000000..b86436503 --- /dev/null +++ b/shared/management/client/rest/google_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// GoogleIDPAPI APIs for Google Workspace IDP integrations +type GoogleIDPAPI struct { + c *Client +} + +// List retrieves all Google Workspace IDP integrations +func (a *GoogleIDPAPI) List(ctx context.Context) ([]api.GoogleIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.GoogleIntegration](resp) + return ret, err +} + +// Get retrieves a specific Google Workspace IDP integration by ID +func (a *GoogleIDPAPI) Get(ctx context.Context, integrationID string) (*api.GoogleIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Create creates a new Google Workspace IDP integration +func (a *GoogleIDPAPI) Create(ctx context.Context, request api.CreateGoogleIntegrationRequest) (*api.GoogleIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Update updates an existing Google Workspace IDP integration +func (a *GoogleIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateGoogleIntegrationRequest) (*api.GoogleIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/google-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Delete deletes a Google Workspace IDP integration +func (a *GoogleIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/google-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// Sync triggers a manual sync for a Google Workspace IDP integration +func (a *GoogleIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp/"+integrationID+"/sync", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.SyncResult](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for a Google Workspace IDP integration +func (a *GoogleIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/google_idp_test.go b/shared/management/client/rest/google_idp_test.go new file mode 100644 index 000000000..03a6c161e --- /dev/null +++ b/shared/management/client/rest/google_idp_test.go @@ -0,0 +1,248 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testGoogleIntegration = api.GoogleIntegration{ + Id: 1, + Enabled: true, + CustomerId: "C01234567", + SyncInterval: 300, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestGoogleIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.GoogleIntegration{testGoogleIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testGoogleIntegration, ret[0]) + }) +} + +func TestGoogleIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestGoogleIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateGoogleIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "C01234567", req.CustomerId) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{ + CustomerId: "C01234567", + ServiceAccountKey: "key-data", + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateGoogleIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.GoogleIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestGoogleIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.GoogleIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestGoogleIDP_Sync_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Sync(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, "ok", *ret.Result) + }) +} + +func TestGoogleIDP_Sync_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Sync(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestGoogleIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/client/rest/okta_scim_idp.go b/shared/management/client/rest/okta_scim_idp.go new file mode 100644 index 000000000..eb677dae8 --- /dev/null +++ b/shared/management/client/rest/okta_scim_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// OktaScimIDPAPI APIs for Okta SCIM IDP integrations +type OktaScimIDPAPI struct { + c *Client +} + +// List retrieves all Okta SCIM IDP integrations +func (a *OktaScimIDPAPI) List(ctx context.Context) ([]api.OktaScimIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.OktaScimIntegration](resp) + return ret, err +} + +// Get retrieves a specific Okta SCIM IDP integration by ID +func (a *OktaScimIDPAPI) Get(ctx context.Context, integrationID string) (*api.OktaScimIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Create creates a new Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Create(ctx context.Context, request api.CreateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Update updates an existing Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/okta-scim-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Delete deletes an Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// RegenerateToken regenerates the SCIM API token for an Okta SCIM integration +func (a *OktaScimIDPAPI) RegenerateToken(ctx context.Context, integrationID string) (*api.ScimTokenResponse, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp/"+integrationID+"/token", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.ScimTokenResponse](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for an Okta SCIM IDP integration +func (a *OktaScimIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/okta_scim_idp_test.go b/shared/management/client/rest/okta_scim_idp_test.go new file mode 100644 index 000000000..d8d1f2b51 --- /dev/null +++ b/shared/management/client/rest/okta_scim_idp_test.go @@ -0,0 +1,246 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testOktaScimIntegration = api.OktaScimIntegration{ + Id: 1, + AuthToken: "****", + Enabled: true, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestOktaScimIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.OktaScimIntegration{testOktaScimIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testOktaScimIntegration, ret[0]) + }) +} + +func TestOktaScimIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestOktaScimIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateOktaScimIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "my-okta-connection", req.ConnectionName) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{ + ConnectionName: "my-okta-connection", + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateOktaScimIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.OktaScimIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestOktaScimIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.OktaScimIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestOktaScimIDP_RegenerateToken_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(testScimToken) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testScimToken, *ret) + }) +} + +func TestOktaScimIDP_RegenerateToken_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestOktaScimIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 40d306bc1..3b5043852 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -68,8 +68,17 @@ tags: - name: MSP description: MSP portal for Tenant management. x-cloud-only: true - - name: IDP - description: Manage identity provider integrations for user and group sync. + - name: IDP SCIM Integrations + description: Manage generic SCIM identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Google Integrations + description: Manage Google Workspace identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Azure Integrations + description: Manage Azure AD identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Okta SCIM Integrations + description: Manage Okta SCIM identity provider integrations for user and group sync. x-cloud-only: true - name: EDR Intune Integrations description: Manage Microsoft Intune EDR integrations. @@ -83,6 +92,9 @@ tags: - name: EDR Huntress Integrations description: Manage Huntress EDR integrations. x-cloud-only: true + - name: EDR FleetDM Integrations + description: Manage FleetDM EDR integrations. + x-cloud-only: true - name: EDR Peers description: Manage EDR compliance bypass for peers. x-cloud-only: true @@ -1675,15 +1687,18 @@ components: - locations - action PeerNetworkRangeCheck: - description: Posture check for allow or deny access based on peer local network addresses + description: | + Posture check for allow or deny access based on the peer's IP addresses. A range matches when it + contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, + so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. type: object properties: ranges: - description: List of peer network ranges in CIDR notation + description: List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP type: array items: type: string - example: [ "192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56" ] + example: [ "192.168.1.0/24", "10.0.0.0/8", "1.0.0.0/24", "2.2.2.2/32", "2001:db8:1234:1a00::/56" ] action: description: Action to take upon policy match type: string @@ -2848,6 +2863,11 @@ components: type: string description: "Protocol type: http, tcp, or udp" example: "http" + metadata: + type: object + additionalProperties: + type: string + description: "Extra context about the request (e.g. crowdsec_verdict)" required: - id - service_id @@ -2900,6 +2920,7 @@ components: - okta - pocketid - microsoft + - adfs example: oidc IdentityProvider: type: object @@ -3246,6 +3267,14 @@ components: pattern: '^[a-zA-Z]{2}$' example: "DE" description: ISO 3166-1 alpha-2 country codes to block. + crowdsec_mode: + type: string + enum: + - "off" + - "enforce" + - "observe" + default: "off" + description: CrowdSec IP reputation mode. Only available when the proxy cluster supports CrowdSec. PasswordAuthConfig: type: object properties: @@ -3349,6 +3378,10 @@ components: type: boolean description: Whether a subdomain label is required in front of this domain. When true, the domain cannot be used bare. example: false + supports_crowdsec: + type: boolean + description: Whether the proxy cluster has CrowdSec configured + example: false required: - id - domain @@ -3396,6 +3429,17 @@ components: description: Display name for the admin user (defaults to email if not provided) type: string example: Admin User + create_pat: + description: If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled. + type: boolean + example: true + pat_expire_in: + description: Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted. + type: integer + minimum: 1 + maximum: 365 + default: 1 + example: 30 required: - email - password @@ -3412,6 +3456,12 @@ components: description: Email address of the created user type: string example: admin@example.com + personal_access_token: + description: Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server. + type: string + format: password + readOnly: true + example: nbp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx required: - user_id - email @@ -4438,75 +4488,129 @@ components: description: Status of agent firewall. Can be one of Disabled, Enabled, Pending Isolation, Isolated, Pending Release. example: "Enabled" - CreateScimIntegrationRequest: + EDRFleetDMRequest: type: object - description: Request payload for creating an SCIM IDP integration - required: - - prefix - - provider + description: Request payload for creating or updating a FleetDM EDR integration properties: - prefix: + api_url: type: string - description: The connection prefix used for the SCIM provider - provider: + description: FleetDM server URL + api_token: type: string - description: Name of the SCIM identity provider - group_prefixes: + description: FleetDM API token + groups: type: array - description: List of start_with string patterns for groups to sync + description: The Groups this integrations applies to items: type: string - example: [ "Engineering", "Sales" ] - user_group_prefixes: - type: array - description: List of start_with string patterns for groups which users to sync - items: - type: string - example: [ "Users" ] - UpdateScimIntegrationRequest: - type: object - description: Request payload for updating an SCIM IDP integration - properties: + last_synced_interval: + type: integer + description: The devices last sync requirement interval in hours. Minimum value is 24 hours + minimum: 24 enabled: type: boolean description: Indicates whether the integration is enabled - example: true - group_prefixes: - type: array - description: List of start_with string patterns for groups to sync - items: - type: string - example: [ "Engineering", "Sales" ] - user_group_prefixes: - type: array - description: List of start_with string patterns for groups which users to sync - items: - type: string - example: [ "Users" ] - ScimIntegration: + default: true + match_attributes: + $ref: '#/components/schemas/FleetDMMatchAttributes' + required: + - api_url + - api_token + - groups + - last_synced_interval + - match_attributes + EDRFleetDMResponse: type: object - description: Represents a SCIM IDP integration + description: Represents a FleetDM EDR integration configuration required: - id - - enabled - - provider - - group_prefixes - - user_group_prefixes - - auth_token + - account_id + - api_url + - created_by - last_synced_at + - created_at + - updated_at + - groups + - last_synced_interval + - match_attributes + - enabled properties: id: type: integer format: int64 - description: The unique identifier for the integration + description: The unique numeric identifier for the integration. example: 123 + account_id: + type: string + description: The identifier of the account this integration belongs to. + example: "ch8i4ug6lnn4g9hqv7l0" + api_url: + type: string + description: FleetDM server URL + last_synced_at: + type: string + format: date-time + description: Timestamp of when the integration was last synced. + example: "2023-05-15T10:30:00Z" + created_by: + type: string + description: The user id that created the integration + created_at: + type: string + format: date-time + description: Timestamp of when the integration was created. + example: "2023-05-15T10:30:00Z" + updated_at: + type: string + format: date-time + description: Timestamp of when the integration was last updated. + example: "2023-05-16T11:45:00Z" + groups: + type: array + description: List of groups + items: + $ref: '#/components/schemas/Group' + last_synced_interval: + type: integer + description: The devices last sync requirement interval in hours. enabled: type: boolean description: Indicates whether the integration is enabled - example: true - provider: - type: string - description: Name of the SCIM identity provider + default: true + match_attributes: + $ref: '#/components/schemas/FleetDMMatchAttributes' + + FleetDMMatchAttributes: + type: object + description: Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly + additionalProperties: false + properties: + disk_encryption_enabled: + type: boolean + description: Whether disk encryption (FileVault/BitLocker) must be enabled on the host + failing_policies_count_max: + type: integer + description: Maximum number of allowed failing policies. Use 0 to require all policies to pass + minimum: 0 + example: 0 + vulnerable_software_count_max: + type: integer + description: Maximum number of allowed vulnerable software on the host + minimum: 0 + example: 0 + status_online: + type: boolean + description: Whether the host must be online (recently seen by Fleet) + required_policies: + type: array + description: List of FleetDM policy IDs that must be passing on the host. If any of these policies is failing, the host is non-compliant + items: + type: integer + example: [1, 5, 12] + + IntegrationSyncFilters: + type: object + properties: group_prefixes: type: array description: List of start_with string patterns for groups to sync @@ -4519,15 +4623,77 @@ components: items: type: string example: [ "Users" ] - auth_token: + connector_id: type: string - description: SCIM API token (full on creation, masked otherwise) - example: "nbs_abc***********************************" - last_synced_at: - type: string - format: date-time - description: Timestamp of when the integration was last synced - example: "2023-05-15T10:30:00Z" + description: DEX connector ID for embedded IDP setups + IntegrationEnabled: + type: object + properties: + enabled: + type: boolean + description: Whether the integration is enabled + example: true + CreateScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an SCIM IDP integration + required: + - prefix + - provider + properties: + prefix: + type: string + description: The connection prefix used for the SCIM provider + provider: + type: string + description: Name of the SCIM identity provider + UpdateScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an SCIM IDP integration + properties: + prefix: + type: string + description: The connection prefix used for the SCIM provider + ScimIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents a SCIM IDP integration + required: + - id + - enabled + - prefix + - provider + - group_prefixes + - user_group_prefixes + - auth_token + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 123 + prefix: + type: string + description: The connection prefix used for the SCIM provider + provider: + type: string + description: Name of the SCIM identity provider + auth_token: + type: string + description: SCIM API token (full on creation, masked otherwise) + example: "nbs_abc***********************************" + last_synced_at: + type: string + format: date-time + description: Timestamp of when the integration was last synced + example: "2023-05-15T10:30:00Z" IdpIntegrationSyncLog: type: object description: Represents a synchronization log entry for an integration @@ -4565,6 +4731,229 @@ components: type: string description: The newly generated SCIM API token example: "nbs_F3f0d..." + CreateGoogleIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating a Google Workspace IDP integration + required: + - service_account_key + - customer_id + properties: + service_account_key: + type: string + description: Base64-encoded Google service account key + example: "eyJ0eXBlIjoic2VydmljZV9hY2NvdW50Ii..." + customer_id: + type: string + description: Customer ID from Google Workspace Account Settings + example: "C01234567" + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + minimum: 300 + example: 300 + UpdateGoogleIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating a Google Workspace IDP integration. All fields are optional. + properties: + service_account_key: + type: string + description: Base64-encoded Google service account key + customer_id: + type: string + description: Customer ID from Google Workspace Account Settings + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300) + minimum: 300 + GoogleIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents a Google Workspace IDP integration + required: + - id + - customer_id + - sync_interval + - enabled + - group_prefixes + - user_group_prefixes + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + customer_id: + type: string + description: Customer ID from Google Workspace + example: "C01234567" + sync_interval: + type: integer + description: Sync interval in seconds + example: 300 + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + CreateAzureIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an Azure AD IDP integration + required: + - client_secret + - client_id + - tenant_id + - host + properties: + client_secret: + type: string + description: Base64-encoded Azure AD client secret + example: "c2VjcmV0..." + client_id: + type: string + description: Azure AD application (client) ID + example: "12345678-1234-1234-1234-123456789012" + tenant_id: + type: string + description: Azure AD tenant ID + example: "87654321-4321-4321-4321-210987654321" + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + minimum: 300 + example: 300 + host: + type: string + description: Azure host domain for the Graph API + enum: + - microsoft.com + - microsoft.us + example: "microsoft.com" + UpdateAzureIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an Azure AD IDP integration. All fields are optional. + properties: + client_secret: + type: string + description: Base64-encoded Azure AD client secret + client_id: + type: string + description: Azure AD application (client) ID + tenant_id: + type: string + description: Azure AD tenant ID + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300) + minimum: 300 + AzureIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents an Azure AD IDP integration + required: + - id + - client_id + - tenant_id + - sync_interval + - enabled + - group_prefixes + - user_group_prefixes + - host + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + client_id: + type: string + description: Azure AD application (client) ID + example: "12345678-1234-1234-1234-123456789012" + tenant_id: + type: string + description: Azure AD tenant ID + example: "87654321-4321-4321-4321-210987654321" + sync_interval: + type: integer + description: Sync interval in seconds + example: 300 + host: + type: string + description: Azure host domain for the Graph API + example: "microsoft.com" + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + CreateOktaScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an Okta SCIM IDP integration + required: + - connection_name + properties: + connection_name: + type: string + description: The Okta enterprise connection name on Auth0 + example: "my-okta-connection" + UpdateOktaScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an Okta SCIM IDP integration. All fields are optional. + OktaScimIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents an Okta SCIM IDP integration + required: + - id + - enabled + - group_prefixes + - user_group_prefixes + - auth_token + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + auth_token: + type: string + description: SCIM API token (full on creation/regeneration, masked on retrieval) + example: "nbs_abc***********************************" + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + SyncResult: + type: object + description: Response for a manual sync trigger + properties: + result: + type: string + example: "ok" NotificationChannelType: type: string description: The type of notification channel. @@ -4782,7 +5171,10 @@ paths: /api/setup: post: summary: Setup Instance - description: Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled). + description: | + Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled). + + When the management server is started with `NB_SETUP_PAT_ENABLED=true` and the request includes `create_pat: true`, the endpoint also provisions the NetBird account for the new owner user and returns the plain text Personal Access Token in `personal_access_token`. The optional `pat_expire_in` value applies only when `create_pat` is true and defaults to 1 day when omitted. If a post-user step fails, setup-created resources are rolled back when safe; if account cleanup fails, the owner user is left in place to avoid leaving an account without its admin user. tags: [ Instance ] security: [ ] requestBody: @@ -4795,6 +5187,12 @@ paths: responses: '200': description: Setup completed successfully + headers: + Cache-Control: + description: Always set to no-store because the response may contain a one-time plain text Personal Access Token. + schema: + type: string + example: no-store content: application/json: schema: @@ -9712,10 +10110,877 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp: + post: + tags: + - IDP Google Integrations + summary: Create Google IDP Integration + description: Creates a new Google Workspace IDP integration + operationId: createGoogleIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateGoogleIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Google Integrations + summary: Get All Google IDP Integrations + description: Retrieves all Google Workspace IDP integrations for the authenticated account + operationId: getAllGoogleIntegrations + responses: + '200': + description: A list of Google IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/GoogleIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Google Integrations + summary: Get Google IDP Integration + description: Retrieves a Google IDP integration by ID. + operationId: getGoogleIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Google Integrations + summary: Update Google IDP Integration + description: Updates an existing Google Workspace IDP integration. + operationId: updateGoogleIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateGoogleIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Google Integrations + summary: Delete Google IDP Integration + description: Deletes a Google IDP integration by ID. + operationId: deleteGoogleIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}/sync: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Google Integrations + summary: Sync Google IDP Integration + description: Triggers a manual synchronization for a Google IDP integration. + operationId: syncGoogleIntegration + responses: + '200': + description: Sync triggered successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/SyncResult' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Google Integrations + summary: Get Google Integration Sync Logs + description: Retrieves synchronization logs for a Google IDP integration. + operationId: getGoogleIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp: + post: + tags: + - IDP Azure Integrations + summary: Create Azure IDP Integration + description: Creates a new Azure AD IDP integration + operationId: createAzureIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateAzureIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Azure Integrations + summary: Get All Azure IDP Integrations + description: Retrieves all Azure AD IDP integrations for the authenticated account + operationId: getAllAzureIntegrations + responses: + '200': + description: A list of Azure IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AzureIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Azure Integrations + summary: Get Azure IDP Integration + description: Retrieves an Azure IDP integration by ID. + operationId: getAzureIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Azure Integrations + summary: Update Azure IDP Integration + description: Updates an existing Azure AD IDP integration. + operationId: updateAzureIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateAzureIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Azure Integrations + summary: Delete Azure IDP Integration + description: Deletes an Azure IDP integration by ID. + operationId: deleteAzureIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}/sync: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Azure Integrations + summary: Sync Azure IDP Integration + description: Triggers a manual synchronization for an Azure IDP integration. + operationId: syncAzureIntegration + responses: + '200': + description: Sync triggered successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/SyncResult' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Azure Integrations + summary: Get Azure Integration Sync Logs + description: Retrieves synchronization logs for an Azure IDP integration. + operationId: getAzureIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp: + post: + tags: + - IDP Okta SCIM Integrations + summary: Create Okta SCIM IDP Integration + description: Creates a new Okta SCIM IDP integration + operationId: createOktaScimIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateOktaScimIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Okta SCIM Integrations + summary: Get All Okta SCIM IDP Integrations + description: Retrieves all Okta SCIM IDP integrations for the authenticated account + operationId: getAllOktaScimIntegrations + responses: + '200': + description: A list of Okta SCIM IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/OktaScimIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Okta SCIM Integrations + summary: Get Okta SCIM IDP Integration + description: Retrieves an Okta SCIM IDP integration by ID. + operationId: getOktaScimIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Okta SCIM Integrations + summary: Update Okta SCIM IDP Integration + description: Updates an existing Okta SCIM IDP integration. + operationId: updateOktaScimIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateOktaScimIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Okta SCIM Integrations + summary: Delete Okta SCIM IDP Integration + description: Deletes an Okta SCIM IDP integration by ID. + operationId: deleteOktaScimIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}/token: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Okta SCIM Integrations + summary: Regenerate Okta SCIM Token + description: Regenerates the SCIM API token for an Okta SCIM IDP integration. + operationId: regenerateOktaScimToken + responses: + '200': + description: Token regenerated successfully. Returns the new token. + content: + application/json: + schema: + $ref: '#/components/schemas/ScimTokenResponse' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Okta SCIM Integrations + summary: Get Okta SCIM Integration Sync Logs + description: Retrieves synchronization logs for an Okta SCIM IDP integration. + operationId: getOktaScimIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/integrations/scim-idp: post: tags: - - IDP + - IDP SCIM Integrations summary: Create SCIM IDP Integration description: Creates a new SCIM integration operationId: createSCIMIntegration @@ -9752,7 +11017,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' get: tags: - - IDP + - IDP SCIM Integrations summary: Get All SCIM IDP Integrations description: Retrieves all SCIM IDP integrations for the authenticated account operationId: getAllSCIMIntegrations @@ -9784,11 +11049,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 get: tags: - - IDP + - IDP SCIM Integrations summary: Get SCIM IDP Integration description: Retrieves an SCIM IDP integration by ID. operationId: getSCIMIntegration @@ -9825,7 +11091,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' put: tags: - - IDP + - IDP SCIM Integrations summary: Update SCIM IDP Integration description: Updates an existing SCIM IDP Integration. operationId: updateSCIMIntegration @@ -9868,7 +11134,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' delete: tags: - - IDP + - IDP SCIM Integrations summary: Delete SCIM IDP Integration description: Deletes an SCIM IDP integration by ID. operationId: deleteSCIMIntegration @@ -9911,11 +11177,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 post: tags: - - IDP + - IDP SCIM Integrations summary: Regenerate SCIM Token description: Regenerates the SCIM API token for an SCIM IDP integration. operationId: regenerateSCIMToken @@ -9957,11 +11224,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 get: tags: - - IDP + - IDP SCIM Integrations summary: Get SCIM Integration Sync Logs description: Retrieves synchronization logs for a SCIM IDP integration. operationId: getSCIMIntegrationLogs @@ -10154,6 +11422,161 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /api/integrations/edr/fleetdm: + post: + tags: + - EDR FleetDM Integrations + summary: Create EDR FleetDM Integration + description: Creates a new EDR FleetDM integration + operationId: createFleetDMEDRIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMResponse' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - EDR FleetDM Integrations + summary: Get EDR FleetDM Integration + description: Retrieves a specific EDR FleetDM integration by its ID. + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMResponse' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - EDR FleetDM Integrations + summary: Update EDR FleetDM Integration + description: Updates an existing EDR FleetDM Integration. + operationId: updateFleetDMEDRIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMResponse' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - EDR FleetDM Integrations + summary: Delete EDR FleetDM Integration + description: Deletes an EDR FleetDM Integration by its ID. + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/peers/{peer-id}/edr/bypass: parameters: - name: peer-id diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index ff8f8056c..62b5541dc 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -17,6 +17,45 @@ const ( TokenAuthScopes tokenAuthContextKey = "TokenAuth.Scopes" ) +// Defines values for AccessRestrictionsCrowdsecMode. +const ( + AccessRestrictionsCrowdsecModeEnforce AccessRestrictionsCrowdsecMode = "enforce" + AccessRestrictionsCrowdsecModeObserve AccessRestrictionsCrowdsecMode = "observe" + AccessRestrictionsCrowdsecModeOff AccessRestrictionsCrowdsecMode = "off" +) + +// Valid indicates whether the value is a known member of the AccessRestrictionsCrowdsecMode enum. +func (e AccessRestrictionsCrowdsecMode) Valid() bool { + switch e { + case AccessRestrictionsCrowdsecModeEnforce: + return true + case AccessRestrictionsCrowdsecModeObserve: + return true + case AccessRestrictionsCrowdsecModeOff: + return true + default: + return false + } +} + +// Defines values for CreateAzureIntegrationRequestHost. +const ( + CreateAzureIntegrationRequestHostMicrosoftCom CreateAzureIntegrationRequestHost = "microsoft.com" + CreateAzureIntegrationRequestHostMicrosoftUs CreateAzureIntegrationRequestHost = "microsoft.us" +) + +// Valid indicates whether the value is a known member of the CreateAzureIntegrationRequestHost enum. +func (e CreateAzureIntegrationRequestHost) Valid() bool { + switch e { + case CreateAzureIntegrationRequestHostMicrosoftCom: + return true + case CreateAzureIntegrationRequestHostMicrosoftUs: + return true + default: + return false + } +} + // Defines values for CreateIntegrationRequestPlatform. const ( CreateIntegrationRequestPlatformDatadog CreateIntegrationRequestPlatform = "datadog" @@ -472,6 +511,7 @@ func (e GroupMinimumIssued) Valid() bool { // Defines values for IdentityProviderType. const ( + IdentityProviderTypeAdfs IdentityProviderType = "adfs" IdentityProviderTypeEntra IdentityProviderType = "entra" IdentityProviderTypeGoogle IdentityProviderType = "google" IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft" @@ -484,6 +524,8 @@ const ( // Valid indicates whether the value is a known member of the IdentityProviderType enum. func (e IdentityProviderType) Valid() bool { switch e { + case IdentityProviderTypeAdfs: + return true case IdentityProviderTypeEntra: return true case IdentityProviderTypeGoogle: @@ -1371,8 +1413,14 @@ type AccessRestrictions struct { // BlockedCountries ISO 3166-1 alpha-2 country codes to block. BlockedCountries *[]string `json:"blocked_countries,omitempty"` + + // CrowdsecMode CrowdSec IP reputation mode. Only available when the proxy cluster supports CrowdSec. + CrowdsecMode *AccessRestrictionsCrowdsecMode `json:"crowdsec_mode,omitempty"` } +// AccessRestrictionsCrowdsecMode CrowdSec IP reputation mode. Only available when the proxy cluster supports CrowdSec. +type AccessRestrictionsCrowdsecMode string + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // CityName Commonly used English name of the city @@ -1532,6 +1580,39 @@ type AvailablePorts struct { Udp int `json:"udp"` } +// AzureIntegration defines model for AzureIntegration. +type AzureIntegration struct { + // ClientId Azure AD application (client) ID + ClientId string `json:"client_id"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Host Azure host domain for the Graph API + Host string `json:"host"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // SyncInterval Sync interval in seconds + SyncInterval int `json:"sync_interval"` + + // TenantId Azure AD tenant ID + TenantId string `json:"tenant_id"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // BearerAuthConfig defines model for BearerAuthConfig. type BearerAuthConfig struct { // DistributionGroups List of group IDs that can use bearer auth @@ -1608,7 +1689,9 @@ type Checks struct { // OsVersionCheck Posture check for the version of operating system OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"` - // PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses + // PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it + // contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, + // so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"` // ProcessCheck Posture Check for binaries exist and are running in the peer’s system @@ -1639,6 +1722,57 @@ type Country struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country type CountryCode = string +// CreateAzureIntegrationRequest defines model for CreateAzureIntegrationRequest. +type CreateAzureIntegrationRequest struct { + // ClientId Azure AD application (client) ID + ClientId string `json:"client_id"` + + // ClientSecret Base64-encoded Azure AD client secret + ClientSecret string `json:"client_secret"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // Host Azure host domain for the Graph API + Host CreateAzureIntegrationRequestHost `json:"host"` + + // SyncInterval Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + SyncInterval *int `json:"sync_interval,omitempty"` + + // TenantId Azure AD tenant ID + TenantId string `json:"tenant_id"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// CreateAzureIntegrationRequestHost Azure host domain for the Graph API +type CreateAzureIntegrationRequestHost string + +// CreateGoogleIntegrationRequest defines model for CreateGoogleIntegrationRequest. +type CreateGoogleIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // CustomerId Customer ID from Google Workspace Account Settings + CustomerId string `json:"customer_id"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // ServiceAccountKey Base64-encoded Google service account key + ServiceAccountKey string `json:"service_account_key"` + + // SyncInterval Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + SyncInterval *int `json:"sync_interval,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // CreateIntegrationRequest Request payload for creating a new event streaming integration. Also used as the structure for the PUT request body, but not all fields are applicable for updates (see PUT operation description). type CreateIntegrationRequest struct { // Config Platform-specific configuration as key-value pairs. For creation, all necessary credentials and settings must be provided. For updates, provide the fields to change or the entire new configuration. @@ -1654,6 +1788,21 @@ type CreateIntegrationRequest struct { // CreateIntegrationRequestPlatform The event streaming platform to integrate with (e.g., "datadog", "s3", "firehose"). This field is used for creation. For updates (PUT), this field, if sent, is ignored by the backend. type CreateIntegrationRequestPlatform string +// CreateOktaScimIntegrationRequest defines model for CreateOktaScimIntegrationRequest. +type CreateOktaScimIntegrationRequest struct { + // ConnectionName The Okta enterprise connection name on Auth0 + ConnectionName string `json:"connection_name"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // CreateResellerMSPRequest defines model for CreateResellerMSPRequest. type CreateResellerMSPRequest struct { // Domain The domain for the MSP @@ -1669,8 +1818,11 @@ type CreateResellerMSPRequest struct { ResellerCustomerId *string `json:"reseller_customer_id,omitempty"` } -// CreateScimIntegrationRequest Request payload for creating an SCIM IDP integration +// CreateScimIntegrationRequest defines model for CreateScimIntegrationRequest. type CreateScimIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + // GroupPrefixes List of start_with string patterns for groups to sync GroupPrefixes *[]string `json:"group_prefixes,omitempty"` @@ -1822,6 +1974,63 @@ type EDRFalconResponse struct { ZtaScoreThreshold int `json:"zta_score_threshold"` } +// EDRFleetDMRequest Request payload for creating or updating a FleetDM EDR integration +type EDRFleetDMRequest struct { + // ApiToken FleetDM API token + ApiToken string `json:"api_token"` + + // ApiUrl FleetDM server URL + ApiUrl string `json:"api_url"` + + // Enabled Indicates whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // Groups The Groups this integrations applies to + Groups []string `json:"groups"` + + // LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours + LastSyncedInterval int `json:"last_synced_interval"` + + // MatchAttributes Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly + MatchAttributes FleetDMMatchAttributes `json:"match_attributes"` +} + +// EDRFleetDMResponse Represents a FleetDM EDR integration configuration +type EDRFleetDMResponse struct { + // AccountId The identifier of the account this integration belongs to. + AccountId string `json:"account_id"` + + // ApiUrl FleetDM server URL + ApiUrl string `json:"api_url"` + + // CreatedAt Timestamp of when the integration was created. + CreatedAt time.Time `json:"created_at"` + + // CreatedBy The user id that created the integration + CreatedBy string `json:"created_by"` + + // Enabled Indicates whether the integration is enabled + Enabled bool `json:"enabled"` + + // Groups List of groups + Groups []Group `json:"groups"` + + // Id The unique numeric identifier for the integration. + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of when the integration was last synced. + LastSyncedAt time.Time `json:"last_synced_at"` + + // LastSyncedInterval The devices last sync requirement interval in hours. + LastSyncedInterval int `json:"last_synced_interval"` + + // MatchAttributes Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly + MatchAttributes FleetDMMatchAttributes `json:"match_attributes"` + + // UpdatedAt Timestamp of when the integration was last updated. + UpdatedAt time.Time `json:"updated_at"` +} + // EDRHuntressRequest Request payload for creating or updating a EDR Huntress integration type EDRHuntressRequest struct { // ApiKey Huntress API key @@ -2035,6 +2244,24 @@ type Event struct { // EventActivityCode The string code of the activity that occurred during the event type EventActivityCode string +// FleetDMMatchAttributes Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly +type FleetDMMatchAttributes struct { + // DiskEncryptionEnabled Whether disk encryption (FileVault/BitLocker) must be enabled on the host + DiskEncryptionEnabled *bool `json:"disk_encryption_enabled,omitempty"` + + // FailingPoliciesCountMax Maximum number of allowed failing policies. Use 0 to require all policies to pass + FailingPoliciesCountMax *int `json:"failing_policies_count_max,omitempty"` + + // RequiredPolicies List of FleetDM policy IDs that must be passing on the host. If any of these policies is failing, the host is non-compliant + RequiredPolicies *[]int `json:"required_policies,omitempty"` + + // StatusOnline Whether the host must be online (recently seen by Fleet) + StatusOnline *bool `json:"status_online,omitempty"` + + // VulnerableSoftwareCountMax Maximum number of allowed vulnerable software on the host + VulnerableSoftwareCountMax *int `json:"vulnerable_software_count_max,omitempty"` +} + // GeoLocationCheck Posture check for geo location type GeoLocationCheck struct { // Action Action to take upon policy match @@ -2053,6 +2280,33 @@ type GetResellerMSPsResponse = []ResellerMSPResponse // GetTenantsResponse defines model for GetTenantsResponse. type GetTenantsResponse = []TenantResponse +// GoogleIntegration defines model for GoogleIntegration. +type GoogleIntegration struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // CustomerId Customer ID from Google Workspace + CustomerId string `json:"customer_id"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // SyncInterval Sync interval in seconds + SyncInterval int `json:"sync_interval"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // Group defines model for Group. type Group struct { // Id Group ID @@ -2344,6 +2598,12 @@ type InstanceVersionInfo struct { ManagementUpdateAvailable bool `json:"management_update_available"` } +// IntegrationEnabled defines model for IntegrationEnabled. +type IntegrationEnabled struct { + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` +} + // IntegrationResponse Represents an event streaming integration. type IntegrationResponse struct { // AccountId The identifier of the account this integration belongs to. @@ -2371,6 +2631,18 @@ type IntegrationResponse struct { // IntegrationResponsePlatform The event streaming platform. type IntegrationResponsePlatform string +// IntegrationSyncFilters defines model for IntegrationSyncFilters. +type IntegrationSyncFilters struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // InvoicePDFResponse defines model for InvoicePDFResponse. type InvoicePDFResponse struct { // Url URL to redirect the user to invoice. @@ -2896,6 +3168,30 @@ type OSVersionCheck struct { Windows *MinKernelVersionCheck `json:"windows,omitempty"` } +// OktaScimIntegration defines model for OktaScimIntegration. +type OktaScimIntegration struct { + // AuthToken SCIM API token (full on creation/regeneration, masked on retrieval) + AuthToken string `json:"auth_token"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // PINAuthConfig defines model for PINAuthConfig. type PINAuthConfig struct { // Enabled Whether PIN auth is enabled @@ -3144,12 +3440,14 @@ type PeerMinimum struct { Name string `json:"name"` } -// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses +// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it +// contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, +// so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. type PeerNetworkRangeCheck struct { // Action Action to take upon policy match Action PeerNetworkRangeCheckAction `json:"action"` - // Ranges List of peer network ranges in CIDR notation + // Ranges List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP Ranges []string `json:"ranges"` } @@ -3542,6 +3840,9 @@ type ProxyAccessLog struct { // Id Unique identifier for the access log entry Id string `json:"id"` + // Metadata Extra context about the request (e.g. crowdsec_verdict) + Metadata *map[string]string `json:"metadata,omitempty"` + // Method HTTP method of the request Method string `json:"method"` @@ -3675,6 +3976,9 @@ type ReverseProxyDomain struct { // RequireSubdomain Whether a subdomain label is required in front of this domain. When true, the domain cannot be used bare. RequireSubdomain *bool `json:"require_subdomain,omitempty"` + // SupportsCrowdsec Whether the proxy cluster has CrowdSec configured + SupportsCrowdsec *bool `json:"supports_crowdsec,omitempty"` + // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` @@ -3799,12 +4103,15 @@ type RulePortRange struct { Start int `json:"start"` } -// ScimIntegration Represents a SCIM IDP integration +// ScimIntegration defines model for ScimIntegration. type ScimIntegration struct { // AuthToken SCIM API token (full on creation, masked otherwise) AuthToken string `json:"auth_token"` - // Enabled Indicates whether the integration is enabled + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled Enabled bool `json:"enabled"` // GroupPrefixes List of start_with string patterns for groups to sync @@ -3816,6 +4123,9 @@ type ScimIntegration struct { // LastSyncedAt Timestamp of when the integration was last synced LastSyncedAt time.Time `json:"last_synced_at"` + // Prefix The connection prefix used for the SCIM provider + Prefix string `json:"prefix"` + // Provider Name of the SCIM identity provider Provider string `json:"provider"` @@ -4171,6 +4481,9 @@ type SetupKeyRequest struct { // SetupRequest Request to set up the initial admin user type SetupRequest struct { + // CreatePat If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled. + CreatePat *bool `json:"create_pat,omitempty"` + // Email Email address for the admin user Email string `json:"email"` @@ -4179,6 +4492,9 @@ type SetupRequest struct { // Password Password for the admin user (minimum 8 characters) Password string `json:"password"` + + // PatExpireIn Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted. + PatExpireIn *int `json:"pat_expire_in,omitempty"` } // SetupResponse Response after successful instance setup @@ -4186,6 +4502,9 @@ type SetupResponse struct { // Email Email address of the created user Email string `json:"email"` + // PersonalAccessToken Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server. + PersonalAccessToken *string `json:"personal_access_token,omitempty"` + // UserId The ID of the created user UserId string `json:"user_id"` } @@ -4220,6 +4539,11 @@ type Subscription struct { UpdatedAt time.Time `json:"updated_at"` } +// SyncResult Response for a manual sync trigger +type SyncResult struct { + Result *string `json:"result,omitempty"` +} + // TenantGroupResponse defines model for TenantGroupResponse. type TenantGroupResponse struct { // Id The Group ID @@ -4265,6 +4589,72 @@ type TenantResponse struct { // TenantResponseStatus The status of the tenant type TenantResponseStatus string +// UpdateAzureIntegrationRequest defines model for UpdateAzureIntegrationRequest. +type UpdateAzureIntegrationRequest struct { + // ClientId Azure AD application (client) ID + ClientId *string `json:"client_id,omitempty"` + + // ClientSecret Base64-encoded Azure AD client secret + ClientSecret *string `json:"client_secret,omitempty"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // SyncInterval Sync interval in seconds (minimum 300) + SyncInterval *int `json:"sync_interval,omitempty"` + + // TenantId Azure AD tenant ID + TenantId *string `json:"tenant_id,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateGoogleIntegrationRequest defines model for UpdateGoogleIntegrationRequest. +type UpdateGoogleIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // CustomerId Customer ID from Google Workspace Account Settings + CustomerId *string `json:"customer_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // ServiceAccountKey Base64-encoded Google service account key + ServiceAccountKey *string `json:"service_account_key,omitempty"` + + // SyncInterval Sync interval in seconds (minimum 300) + SyncInterval *int `json:"sync_interval,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateOktaScimIntegrationRequest defines model for UpdateOktaScimIntegrationRequest. +type UpdateOktaScimIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // UpdateResellerMSPRequest Mutable fields the reseller can update on a managed MSP. type UpdateResellerMSPRequest struct { // Name New display name for the MSP. Whitespace is trimmed; empty/whitespace-only values are rejected. @@ -4274,14 +4664,20 @@ type UpdateResellerMSPRequest struct { ResellerCustomerId *string `json:"reseller_customer_id,omitempty"` } -// UpdateScimIntegrationRequest Request payload for updating an SCIM IDP integration +// UpdateScimIntegrationRequest defines model for UpdateScimIntegrationRequest. type UpdateScimIntegrationRequest struct { - // Enabled Indicates whether the integration is enabled + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled Enabled *bool `json:"enabled,omitempty"` // GroupPrefixes List of start_with string patterns for groups to sync GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + // Prefix The connection prefix used for the SCIM provider + Prefix *string `json:"prefix,omitempty"` + // UserGroupPrefixes List of start_with string patterns for groups which users to sync UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` } @@ -4834,6 +5230,12 @@ type PostApiIngressPeersJSONRequestBody = IngressPeerCreateRequest // PutApiIngressPeersIngressPeerIdJSONRequestBody defines body for PutApiIngressPeersIngressPeerId for application/json ContentType. type PutApiIngressPeersIngressPeerIdJSONRequestBody = IngressPeerUpdateRequest +// CreateAzureIntegrationJSONRequestBody defines body for CreateAzureIntegration for application/json ContentType. +type CreateAzureIntegrationJSONRequestBody = CreateAzureIntegrationRequest + +// UpdateAzureIntegrationJSONRequestBody defines body for UpdateAzureIntegration for application/json ContentType. +type UpdateAzureIntegrationJSONRequestBody = UpdateAzureIntegrationRequest + // PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody defines body for PostApiIntegrationsBillingAwsMarketplaceActivate for application/json ContentType. type PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody @@ -4852,6 +5254,12 @@ type CreateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest // UpdateFalconEDRIntegrationJSONRequestBody defines body for UpdateFalconEDRIntegration for application/json ContentType. type UpdateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest +// CreateFleetDMEDRIntegrationJSONRequestBody defines body for CreateFleetDMEDRIntegration for application/json ContentType. +type CreateFleetDMEDRIntegrationJSONRequestBody = EDRFleetDMRequest + +// UpdateFleetDMEDRIntegrationJSONRequestBody defines body for UpdateFleetDMEDRIntegration for application/json ContentType. +type UpdateFleetDMEDRIntegrationJSONRequestBody = EDRFleetDMRequest + // CreateHuntressEDRIntegrationJSONRequestBody defines body for CreateHuntressEDRIntegration for application/json ContentType. type CreateHuntressEDRIntegrationJSONRequestBody = EDRHuntressRequest @@ -4870,6 +5278,12 @@ type CreateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest // UpdateSentinelOneEDRIntegrationJSONRequestBody defines body for UpdateSentinelOneEDRIntegration for application/json ContentType. type UpdateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest +// CreateGoogleIntegrationJSONRequestBody defines body for CreateGoogleIntegration for application/json ContentType. +type CreateGoogleIntegrationJSONRequestBody = CreateGoogleIntegrationRequest + +// UpdateGoogleIntegrationJSONRequestBody defines body for UpdateGoogleIntegration for application/json ContentType. +type UpdateGoogleIntegrationJSONRequestBody = UpdateGoogleIntegrationRequest + // PostApiIntegrationsMspJSONRequestBody defines body for PostApiIntegrationsMsp for application/json ContentType. type PostApiIntegrationsMspJSONRequestBody PostApiIntegrationsMspJSONBody @@ -4909,6 +5323,12 @@ type CreateNotificationChannelJSONRequestBody = NotificationChannelRequest // UpdateNotificationChannelJSONRequestBody defines body for UpdateNotificationChannel for application/json ContentType. type UpdateNotificationChannelJSONRequestBody = NotificationChannelRequest +// CreateOktaScimIntegrationJSONRequestBody defines body for CreateOktaScimIntegration for application/json ContentType. +type CreateOktaScimIntegrationJSONRequestBody = CreateOktaScimIntegrationRequest + +// UpdateOktaScimIntegrationJSONRequestBody defines body for UpdateOktaScimIntegration for application/json ContentType. +type UpdateOktaScimIntegrationJSONRequestBody = UpdateOktaScimIntegrationRequest + // CreateSCIMIntegrationJSONRequestBody defines body for CreateSCIMIntegration for application/json ContentType. type CreateSCIMIntegrationJSONRequestBody = CreateScimIntegrationRequest diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index c5581296c..604f9c793 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v7.34.1 // source: management.proto package proto @@ -2259,8 +2259,8 @@ type AutoUpdateSettings struct { unknownFields protoimpl.UnknownFields Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` - // alwaysUpdate = true → Updates happen automatically in the background - // alwaysUpdate = false → Updates only happen when triggered by a peer connection + // alwaysUpdate = true → Updates are installed automatically in the background + // alwaysUpdate = false → Updates require user interaction from the UI AlwaysUpdate bool `protobuf:"varint,2,opt,name=alwaysUpdate,proto3" json:"alwaysUpdate,omitempty"` } @@ -2928,7 +2928,9 @@ type ProviderConfig struct { // An IDP application client id ClientID string `protobuf:"bytes,1,opt,name=ClientID,proto3" json:"ClientID,omitempty"` - // An IDP application client secret + // Deprecated: use embedded IdP for providers that require a client secret (e.g. Google Workspace). + // + // Deprecated: Do not use. ClientSecret string `protobuf:"bytes,2,opt,name=ClientSecret,proto3" json:"ClientSecret,omitempty"` // An IDP API domain // Deprecated. Use a DeviceAuthEndpoint and TokenEndpoint @@ -2992,6 +2994,7 @@ func (x *ProviderConfig) GetClientID() string { return "" } +// Deprecated: Do not use. func (x *ProviderConfig) GetClientSecret() string { if x != nil { return x.ClientSecret @@ -4847,287 +4850,287 @@ var file_management_proto_rawDesc = []byte{ 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x69, 0x67, 0x22, 0xbc, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, - 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, - 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, - 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, - 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, - 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, - 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, - 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, - 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, - 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, - 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, - 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, - 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, - 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, - 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, - 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, - 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, - 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, - 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, - 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, - 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, - 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, - 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, - 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x4e, - 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x22, - 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, - 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, - 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, - 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, - 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, - 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, - 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, - 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, - 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, - 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, - 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, - 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, - 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, - 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, - 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, - 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, - 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, - 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, - 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, - 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, - 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, - 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, - 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, - 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, - 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, - 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, - 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, - 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, - 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, - 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, - 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, - 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x36, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x78, - 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, - 0x77, 0x6f, 0x72, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, - 0x77, 0x6f, 0x72, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, - 0x75, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x47, - 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x0a, - 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x1f, - 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x22, - 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, - 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, - 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x61, 0x75, - 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x73, 0x73, 0x69, 0x67, - 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, - 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2b, 0x0a, 0x11, 0x53, 0x74, 0x6f, 0x70, + 0x44, 0x12, 0x26, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0c, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, + 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, + 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, + 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, + 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, + 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, + 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, + 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, + 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, + 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, + 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, + 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, + 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, + 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, + 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, + 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, + 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, + 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, + 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, + 0x69, 0x76, 0x65, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, + 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, + 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, + 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, + 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, + 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, + 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, + 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, + 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, + 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, + 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, + 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, + 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, + 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, + 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, + 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, + 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, + 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, + 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, + 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, + 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, + 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, + 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, + 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, + 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, + 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, + 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, + 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, + 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, + 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, + 0x64, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, + 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, + 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, + 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x73, 0x65, 0x72, + 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x75, + 0x73, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, + 0x69, 0x78, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, + 0x6f, 0x72, 0x74, 0x22, 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, + 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, + 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, + 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, + 0x74, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, + 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x14, 0x0a, 0x12, 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, - 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x3a, 0x0a, 0x09, 0x4a, - 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x0e, 0x75, 0x6e, 0x6b, 0x6e, - 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, - 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, 0x65, 0x64, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x66, - 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, - 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, - 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, - 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, - 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, - 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, - 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, 0x63, 0x0a, 0x0e, 0x45, - 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0f, 0x0a, - 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x10, - 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x01, - 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x43, 0x50, 0x10, 0x02, - 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, 0x44, 0x50, 0x10, 0x03, - 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x4c, 0x53, 0x10, 0x04, - 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, - 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, - 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, + 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2b, 0x0a, 0x11, + 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x14, 0x0a, 0x12, 0x53, 0x74, 0x6f, + 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, + 0x3a, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x0e, + 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x10, 0x00, + 0x12, 0x0d, 0x0a, 0x09, 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, 0x65, 0x64, 0x10, 0x01, 0x12, + 0x0a, 0x0a, 0x06, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, + 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, + 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, + 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, + 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, + 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, + 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, + 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, + 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, + 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, + 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, + 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, + 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, + 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, + 0x44, 0x50, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, + 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, - 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, - 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, - 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, - 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x4c, 0x0a, 0x0c, 0x43, 0x72, 0x65, - 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0b, 0x52, 0x65, 0x6e, 0x65, 0x77, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, + 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, + 0x74, 0x79, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x4c, 0x0a, + 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0b, 0x52, + 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0a, 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0a, 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, - 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, - 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 9acf7e2b3..70a530679 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -464,8 +464,8 @@ message PKCEAuthorizationFlow { message ProviderConfig { // An IDP application client id string ClientID = 1; - // An IDP application client secret - string ClientSecret = 2; + // Deprecated: use embedded IdP for providers that require a client secret (e.g. Google Workspace). + string ClientSecret = 2 [deprecated = true]; // An IDP API domain // Deprecated. Use a DeviceAuthEndpoint and TokenEndpoint string Domain = 3; diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 93295e857..1095b6411 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 -// protoc v6.33.3 +// protoc-gen-go v1.26.0 +// protoc v7.34.1 // source: proxy_service.proto package proto @@ -13,7 +13,6 @@ import ( timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" - unsafe "unsafe" ) const ( @@ -178,21 +177,26 @@ func (ProxyStatus) EnumDescriptor() ([]byte, []int) { // ProxyCapabilities describes what a proxy can handle. type ProxyCapabilities struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. SupportsCustomPorts *bool `protobuf:"varint,1,opt,name=supports_custom_ports,json=supportsCustomPorts,proto3,oneof" json:"supports_custom_ports,omitempty"` // Whether the proxy requires a subdomain label in front of its cluster domain. - // When true, tenants cannot use the cluster domain bare. + // When true, accounts cannot use the cluster domain bare. RequireSubdomain *bool `protobuf:"varint,2,opt,name=require_subdomain,json=requireSubdomain,proto3,oneof" json:"require_subdomain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Whether the proxy has CrowdSec configured and can enforce IP reputation checks. + SupportsCrowdsec *bool `protobuf:"varint,3,opt,name=supports_crowdsec,json=supportsCrowdsec,proto3,oneof" json:"supports_crowdsec,omitempty"` } func (x *ProxyCapabilities) Reset() { *x = ProxyCapabilities{} - mi := &file_proxy_service_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ProxyCapabilities) String() string { @@ -203,7 +207,7 @@ func (*ProxyCapabilities) ProtoMessage() {} func (x *ProxyCapabilities) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[0] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -232,23 +236,33 @@ func (x *ProxyCapabilities) GetRequireSubdomain() bool { return false } +func (x *ProxyCapabilities) GetSupportsCrowdsec() bool { + if x != nil && x.SupportsCrowdsec != nil { + return *x.SupportsCrowdsec + } + return false +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. type GetMappingUpdateRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` - Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` - StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` - Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` - Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` } func (x *GetMappingUpdateRequest) Reset() { *x = GetMappingUpdateRequest{} - mi := &file_proxy_service_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *GetMappingUpdateRequest) String() string { @@ -259,7 +273,7 @@ func (*GetMappingUpdateRequest) ProtoMessage() {} func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[1] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -313,20 +327,23 @@ func (x *GetMappingUpdateRequest) GetCapabilities() *ProxyCapabilities { // No mappings may be sent to test the liveness of the Proxy. // Mappings that are sent should be interpreted by the Proxy appropriately. type GetMappingUpdateResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - Mapping []*ProxyMapping `protobuf:"bytes,1,rep,name=mapping,proto3" json:"mapping,omitempty"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Mapping []*ProxyMapping `protobuf:"bytes,1,rep,name=mapping,proto3" json:"mapping,omitempty"` // initial_sync_complete is set on the last message of the initial snapshot. // The proxy uses this to signal that startup is complete. InitialSyncComplete bool `protobuf:"varint,2,opt,name=initial_sync_complete,json=initialSyncComplete,proto3" json:"initial_sync_complete,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache } func (x *GetMappingUpdateResponse) Reset() { *x = GetMappingUpdateResponse{} - mi := &file_proxy_service_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *GetMappingUpdateResponse) String() string { @@ -337,7 +354,7 @@ func (*GetMappingUpdateResponse) ProtoMessage() {} func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[2] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -367,24 +384,27 @@ func (x *GetMappingUpdateResponse) GetInitialSyncComplete() bool { } type PathTargetOptions struct { - state protoimpl.MessageState `protogen:"open.v1"` - SkipTlsVerify bool `protobuf:"varint,1,opt,name=skip_tls_verify,json=skipTlsVerify,proto3" json:"skip_tls_verify,omitempty"` - RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` - PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` - CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SkipTlsVerify bool `protobuf:"varint,1,opt,name=skip_tls_verify,json=skipTlsVerify,proto3" json:"skip_tls_verify,omitempty"` + RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` + PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` + CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` // Send PROXY protocol v2 header to this backend. ProxyProtocol bool `protobuf:"varint,5,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` // Idle timeout before a UDP session is reaped. SessionIdleTimeout *durationpb.Duration `protobuf:"bytes,6,opt,name=session_idle_timeout,json=sessionIdleTimeout,proto3" json:"session_idle_timeout,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache } func (x *PathTargetOptions) Reset() { *x = PathTargetOptions{} - mi := &file_proxy_service_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *PathTargetOptions) String() string { @@ -395,7 +415,7 @@ func (*PathTargetOptions) ProtoMessage() {} func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[3] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -453,19 +473,22 @@ func (x *PathTargetOptions) GetSessionIdleTimeout() *durationpb.Duration { } type PathMapping struct { - state protoimpl.MessageState `protogen:"open.v1"` - Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` - Target string `protobuf:"bytes,2,opt,name=target,proto3" json:"target,omitempty"` - Options *PathTargetOptions `protobuf:"bytes,3,opt,name=options,proto3" json:"options,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Target string `protobuf:"bytes,2,opt,name=target,proto3" json:"target,omitempty"` + Options *PathTargetOptions `protobuf:"bytes,3,opt,name=options,proto3" json:"options,omitempty"` } func (x *PathMapping) Reset() { *x = PathMapping{} - mi := &file_proxy_service_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *PathMapping) String() string { @@ -476,7 +499,7 @@ func (*PathMapping) ProtoMessage() {} func (x *PathMapping) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[4] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -513,20 +536,23 @@ func (x *PathMapping) GetOptions() *PathTargetOptions { } type HeaderAuth struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + // Header name to check, e.g. "Authorization", "X-API-Key". Header string `protobuf:"bytes,1,opt,name=header,proto3" json:"header,omitempty"` // argon2id hash of the expected full header value. - HashedValue string `protobuf:"bytes,2,opt,name=hashed_value,json=hashedValue,proto3" json:"hashed_value,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + HashedValue string `protobuf:"bytes,2,opt,name=hashed_value,json=hashedValue,proto3" json:"hashed_value,omitempty"` } func (x *HeaderAuth) Reset() { *x = HeaderAuth{} - mi := &file_proxy_service_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *HeaderAuth) String() string { @@ -537,7 +563,7 @@ func (*HeaderAuth) ProtoMessage() {} func (x *HeaderAuth) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[5] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -567,22 +593,25 @@ func (x *HeaderAuth) GetHashedValue() string { } type Authentication struct { - state protoimpl.MessageState `protogen:"open.v1"` - SessionKey string `protobuf:"bytes,1,opt,name=session_key,json=sessionKey,proto3" json:"session_key,omitempty"` - MaxSessionAgeSeconds int64 `protobuf:"varint,2,opt,name=max_session_age_seconds,json=maxSessionAgeSeconds,proto3" json:"max_session_age_seconds,omitempty"` - Password bool `protobuf:"varint,3,opt,name=password,proto3" json:"password,omitempty"` - Pin bool `protobuf:"varint,4,opt,name=pin,proto3" json:"pin,omitempty"` - Oidc bool `protobuf:"varint,5,opt,name=oidc,proto3" json:"oidc,omitempty"` - HeaderAuths []*HeaderAuth `protobuf:"bytes,6,rep,name=header_auths,json=headerAuths,proto3" json:"header_auths,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SessionKey string `protobuf:"bytes,1,opt,name=session_key,json=sessionKey,proto3" json:"session_key,omitempty"` + MaxSessionAgeSeconds int64 `protobuf:"varint,2,opt,name=max_session_age_seconds,json=maxSessionAgeSeconds,proto3" json:"max_session_age_seconds,omitempty"` + Password bool `protobuf:"varint,3,opt,name=password,proto3" json:"password,omitempty"` + Pin bool `protobuf:"varint,4,opt,name=pin,proto3" json:"pin,omitempty"` + Oidc bool `protobuf:"varint,5,opt,name=oidc,proto3" json:"oidc,omitempty"` + HeaderAuths []*HeaderAuth `protobuf:"bytes,6,rep,name=header_auths,json=headerAuths,proto3" json:"header_auths,omitempty"` } func (x *Authentication) Reset() { *x = Authentication{} - mi := &file_proxy_service_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *Authentication) String() string { @@ -593,7 +622,7 @@ func (*Authentication) ProtoMessage() {} func (x *Authentication) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[6] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -651,20 +680,25 @@ func (x *Authentication) GetHeaderAuths() []*HeaderAuth { } type AccessRestrictions struct { - state protoimpl.MessageState `protogen:"open.v1"` - AllowedCidrs []string `protobuf:"bytes,1,rep,name=allowed_cidrs,json=allowedCidrs,proto3" json:"allowed_cidrs,omitempty"` - BlockedCidrs []string `protobuf:"bytes,2,rep,name=blocked_cidrs,json=blockedCidrs,proto3" json:"blocked_cidrs,omitempty"` - AllowedCountries []string `protobuf:"bytes,3,rep,name=allowed_countries,json=allowedCountries,proto3" json:"allowed_countries,omitempty"` - BlockedCountries []string `protobuf:"bytes,4,rep,name=blocked_countries,json=blockedCountries,proto3" json:"blocked_countries,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AllowedCidrs []string `protobuf:"bytes,1,rep,name=allowed_cidrs,json=allowedCidrs,proto3" json:"allowed_cidrs,omitempty"` + BlockedCidrs []string `protobuf:"bytes,2,rep,name=blocked_cidrs,json=blockedCidrs,proto3" json:"blocked_cidrs,omitempty"` + AllowedCountries []string `protobuf:"bytes,3,rep,name=allowed_countries,json=allowedCountries,proto3" json:"allowed_countries,omitempty"` + BlockedCountries []string `protobuf:"bytes,4,rep,name=blocked_countries,json=blockedCountries,proto3" json:"blocked_countries,omitempty"` + // CrowdSec IP reputation mode: "", "off", "enforce", or "observe". + CrowdsecMode string `protobuf:"bytes,5,opt,name=crowdsec_mode,json=crowdsecMode,proto3" json:"crowdsec_mode,omitempty"` } func (x *AccessRestrictions) Reset() { *x = AccessRestrictions{} - mi := &file_proxy_service_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *AccessRestrictions) String() string { @@ -675,7 +709,7 @@ func (*AccessRestrictions) ProtoMessage() {} func (x *AccessRestrictions) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[7] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -718,8 +752,18 @@ func (x *AccessRestrictions) GetBlockedCountries() []string { return nil } +func (x *AccessRestrictions) GetCrowdsecMode() string { + if x != nil { + return x.CrowdsecMode + } + return "" +} + type ProxyMapping struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + Type ProxyMappingUpdateType `protobuf:"varint,1,opt,name=type,proto3,enum=management.ProxyMappingUpdateType" json:"type,omitempty"` Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` AccountId string `protobuf:"bytes,3,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` @@ -738,15 +782,15 @@ type ProxyMapping struct { // For L4/TLS: the port the proxy listens on. ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` AccessRestrictions *AccessRestrictions `protobuf:"bytes,12,opt,name=access_restrictions,json=accessRestrictions,proto3" json:"access_restrictions,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache } func (x *ProxyMapping) Reset() { *x = ProxyMapping{} - mi := &file_proxy_service_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ProxyMapping) String() string { @@ -757,7 +801,7 @@ func (*ProxyMapping) ProtoMessage() {} func (x *ProxyMapping) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[8] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -858,17 +902,20 @@ func (x *ProxyMapping) GetAccessRestrictions() *AccessRestrictions { // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. type SendAccessLogRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Log *AccessLog `protobuf:"bytes,1,opt,name=log,proto3" json:"log,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Log *AccessLog `protobuf:"bytes,1,opt,name=log,proto3" json:"log,omitempty"` } func (x *SendAccessLogRequest) Reset() { *x = SendAccessLogRequest{} - mi := &file_proxy_service_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *SendAccessLogRequest) String() string { @@ -879,7 +926,7 @@ func (*SendAccessLogRequest) ProtoMessage() {} func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[9] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -903,16 +950,18 @@ func (x *SendAccessLogRequest) GetLog() *AccessLog { // SendAccessLogResponse is intentionally empty to allow for future expansion. type SendAccessLogResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields } func (x *SendAccessLogResponse) Reset() { *x = SendAccessLogResponse{} - mi := &file_proxy_service_proto_msgTypes[10] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *SendAccessLogResponse) String() string { @@ -923,7 +972,7 @@ func (*SendAccessLogResponse) ProtoMessage() {} func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[10] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -939,7 +988,10 @@ func (*SendAccessLogResponse) Descriptor() ([]byte, []int) { } type AccessLog struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + Timestamp *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` LogId string `protobuf:"bytes,2,opt,name=log_id,json=logId,proto3" json:"log_id,omitempty"` AccountId string `protobuf:"bytes,3,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` @@ -956,15 +1008,17 @@ type AccessLog struct { BytesUpload int64 `protobuf:"varint,14,opt,name=bytes_upload,json=bytesUpload,proto3" json:"bytes_upload,omitempty"` BytesDownload int64 `protobuf:"varint,15,opt,name=bytes_download,json=bytesDownload,proto3" json:"bytes_download,omitempty"` Protocol string `protobuf:"bytes,16,opt,name=protocol,proto3" json:"protocol,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Extra key-value metadata for the access log entry (e.g. crowdsec_verdict, scenario). + Metadata map[string]string `protobuf:"bytes,17,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (x *AccessLog) Reset() { *x = AccessLog{} - mi := &file_proxy_service_proto_msgTypes[11] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *AccessLog) String() string { @@ -975,7 +1029,7 @@ func (*AccessLog) ProtoMessage() {} func (x *AccessLog) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[11] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1102,25 +1156,35 @@ func (x *AccessLog) GetProtocol() string { return "" } +func (x *AccessLog) GetMetadata() map[string]string { + if x != nil { + return x.Metadata + } + return nil +} + type AuthenticateRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - // Types that are valid to be assigned to Request: + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + // Types that are assignable to Request: // // *AuthenticateRequest_Password // *AuthenticateRequest_Pin // *AuthenticateRequest_HeaderAuth - Request isAuthenticateRequest_Request `protobuf_oneof:"request"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Request isAuthenticateRequest_Request `protobuf_oneof:"request"` } func (x *AuthenticateRequest) Reset() { *x = AuthenticateRequest{} - mi := &file_proxy_service_proto_msgTypes[12] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *AuthenticateRequest) String() string { @@ -1131,7 +1195,7 @@ func (*AuthenticateRequest) ProtoMessage() {} func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[12] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1160,36 +1224,30 @@ func (x *AuthenticateRequest) GetAccountId() string { return "" } -func (x *AuthenticateRequest) GetRequest() isAuthenticateRequest_Request { - if x != nil { - return x.Request +func (m *AuthenticateRequest) GetRequest() isAuthenticateRequest_Request { + if m != nil { + return m.Request } return nil } func (x *AuthenticateRequest) GetPassword() *PasswordRequest { - if x != nil { - if x, ok := x.Request.(*AuthenticateRequest_Password); ok { - return x.Password - } + if x, ok := x.GetRequest().(*AuthenticateRequest_Password); ok { + return x.Password } return nil } func (x *AuthenticateRequest) GetPin() *PinRequest { - if x != nil { - if x, ok := x.Request.(*AuthenticateRequest_Pin); ok { - return x.Pin - } + if x, ok := x.GetRequest().(*AuthenticateRequest_Pin); ok { + return x.Pin } return nil } func (x *AuthenticateRequest) GetHeaderAuth() *HeaderAuthRequest { - if x != nil { - if x, ok := x.Request.(*AuthenticateRequest_HeaderAuth); ok { - return x.HeaderAuth - } + if x, ok := x.GetRequest().(*AuthenticateRequest_HeaderAuth); ok { + return x.HeaderAuth } return nil } @@ -1217,18 +1275,21 @@ func (*AuthenticateRequest_Pin) isAuthenticateRequest_Request() {} func (*AuthenticateRequest_HeaderAuth) isAuthenticateRequest_Request() {} type HeaderAuthRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - HeaderValue string `protobuf:"bytes,1,opt,name=header_value,json=headerValue,proto3" json:"header_value,omitempty"` - HeaderName string `protobuf:"bytes,2,opt,name=header_name,json=headerName,proto3" json:"header_name,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + HeaderValue string `protobuf:"bytes,1,opt,name=header_value,json=headerValue,proto3" json:"header_value,omitempty"` + HeaderName string `protobuf:"bytes,2,opt,name=header_name,json=headerName,proto3" json:"header_name,omitempty"` } func (x *HeaderAuthRequest) Reset() { *x = HeaderAuthRequest{} - mi := &file_proxy_service_proto_msgTypes[13] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *HeaderAuthRequest) String() string { @@ -1239,7 +1300,7 @@ func (*HeaderAuthRequest) ProtoMessage() {} func (x *HeaderAuthRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[13] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1269,17 +1330,20 @@ func (x *HeaderAuthRequest) GetHeaderName() string { } type PasswordRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` } func (x *PasswordRequest) Reset() { *x = PasswordRequest{} - mi := &file_proxy_service_proto_msgTypes[14] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *PasswordRequest) String() string { @@ -1290,7 +1354,7 @@ func (*PasswordRequest) ProtoMessage() {} func (x *PasswordRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[14] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1313,17 +1377,20 @@ func (x *PasswordRequest) GetPassword() string { } type PinRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Pin string `protobuf:"bytes,1,opt,name=pin,proto3" json:"pin,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Pin string `protobuf:"bytes,1,opt,name=pin,proto3" json:"pin,omitempty"` } func (x *PinRequest) Reset() { *x = PinRequest{} - mi := &file_proxy_service_proto_msgTypes[15] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *PinRequest) String() string { @@ -1334,7 +1401,7 @@ func (*PinRequest) ProtoMessage() {} func (x *PinRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[15] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1357,18 +1424,21 @@ func (x *PinRequest) GetPin() string { } type AuthenticateResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` - SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` } func (x *AuthenticateResponse) Reset() { *x = AuthenticateResponse{} - mi := &file_proxy_service_proto_msgTypes[16] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *AuthenticateResponse) String() string { @@ -1379,7 +1449,7 @@ func (*AuthenticateResponse) ProtoMessage() {} func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[16] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1410,21 +1480,24 @@ func (x *AuthenticateResponse) GetSessionToken() string { // SendStatusUpdateRequest is sent by the proxy to update its status type SendStatusUpdateRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - Status ProxyStatus `protobuf:"varint,3,opt,name=status,proto3,enum=management.ProxyStatus" json:"status,omitempty"` - CertificateIssued bool `protobuf:"varint,4,opt,name=certificate_issued,json=certificateIssued,proto3" json:"certificate_issued,omitempty"` - ErrorMessage *string `protobuf:"bytes,5,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + Status ProxyStatus `protobuf:"varint,3,opt,name=status,proto3,enum=management.ProxyStatus" json:"status,omitempty"` + CertificateIssued bool `protobuf:"varint,4,opt,name=certificate_issued,json=certificateIssued,proto3" json:"certificate_issued,omitempty"` + ErrorMessage *string `protobuf:"bytes,5,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` } func (x *SendStatusUpdateRequest) Reset() { *x = SendStatusUpdateRequest{} - mi := &file_proxy_service_proto_msgTypes[17] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *SendStatusUpdateRequest) String() string { @@ -1435,7 +1508,7 @@ func (*SendStatusUpdateRequest) ProtoMessage() {} func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[17] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1487,16 +1560,18 @@ func (x *SendStatusUpdateRequest) GetErrorMessage() string { // SendStatusUpdateResponse is intentionally empty to allow for future expansion type SendStatusUpdateResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields } func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} - mi := &file_proxy_service_proto_msgTypes[18] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *SendStatusUpdateResponse) String() string { @@ -1507,7 +1582,7 @@ func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[18] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1525,21 +1600,24 @@ func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { // CreateProxyPeerRequest is sent by the proxy to create a peer connection // The token is a one-time authentication token sent via ProxyMapping type CreateProxyPeerRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - Token string `protobuf:"bytes,3,opt,name=token,proto3" json:"token,omitempty"` - WireguardPublicKey string `protobuf:"bytes,4,opt,name=wireguard_public_key,json=wireguardPublicKey,proto3" json:"wireguard_public_key,omitempty"` - Cluster string `protobuf:"bytes,5,opt,name=cluster,proto3" json:"cluster,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + Token string `protobuf:"bytes,3,opt,name=token,proto3" json:"token,omitempty"` + WireguardPublicKey string `protobuf:"bytes,4,opt,name=wireguard_public_key,json=wireguardPublicKey,proto3" json:"wireguard_public_key,omitempty"` + Cluster string `protobuf:"bytes,5,opt,name=cluster,proto3" json:"cluster,omitempty"` } func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} - mi := &file_proxy_service_proto_msgTypes[19] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *CreateProxyPeerRequest) String() string { @@ -1550,7 +1628,7 @@ func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[19] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1602,18 +1680,21 @@ func (x *CreateProxyPeerRequest) GetCluster() string { // CreateProxyPeerResponse contains the result of peer creation type CreateProxyPeerResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` - ErrorMessage *string `protobuf:"bytes,2,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMessage *string `protobuf:"bytes,2,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` } func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} - mi := &file_proxy_service_proto_msgTypes[20] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *CreateProxyPeerResponse) String() string { @@ -1624,7 +1705,7 @@ func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[20] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1654,19 +1735,22 @@ func (x *CreateProxyPeerResponse) GetErrorMessage() string { } type GetOIDCURLRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - RedirectUrl string `protobuf:"bytes,3,opt,name=redirect_url,json=redirectUrl,proto3" json:"redirect_url,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + RedirectUrl string `protobuf:"bytes,3,opt,name=redirect_url,json=redirectUrl,proto3" json:"redirect_url,omitempty"` } func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} - mi := &file_proxy_service_proto_msgTypes[21] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[21] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *GetOIDCURLRequest) String() string { @@ -1677,7 +1761,7 @@ func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[21] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1714,17 +1798,20 @@ func (x *GetOIDCURLRequest) GetRedirectUrl() string { } type GetOIDCURLResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` } func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} - mi := &file_proxy_service_proto_msgTypes[22] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *GetOIDCURLResponse) String() string { @@ -1735,7 +1822,7 @@ func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[22] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1758,18 +1845,21 @@ func (x *GetOIDCURLResponse) GetUrl() string { } type ValidateSessionRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` - SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` + SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` } func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} - mi := &file_proxy_service_proto_msgTypes[23] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ValidateSessionRequest) String() string { @@ -1780,7 +1870,7 @@ func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[23] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1810,20 +1900,23 @@ func (x *ValidateSessionRequest) GetSessionToken() string { } type ValidateSessionResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - Valid bool `protobuf:"varint,1,opt,name=valid,proto3" json:"valid,omitempty"` - UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` - UserEmail string `protobuf:"bytes,3,opt,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` - DeniedReason string `protobuf:"bytes,4,opt,name=denied_reason,json=deniedReason,proto3" json:"denied_reason,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Valid bool `protobuf:"varint,1,opt,name=valid,proto3" json:"valid,omitempty"` + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + UserEmail string `protobuf:"bytes,3,opt,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` + DeniedReason string `protobuf:"bytes,4,opt,name=denied_reason,json=deniedReason,proto3" json:"denied_reason,omitempty"` } func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} - mi := &file_proxy_service_proto_msgTypes[24] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[24] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ValidateSessionResponse) String() string { @@ -1834,7 +1927,7 @@ func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[24] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1879,195 +1972,370 @@ func (x *ValidateSessionResponse) GetDeniedReason() string { var File_proxy_service_proto protoreflect.FileDescriptor -const file_proxy_service_proto_rawDesc = "" + - "\n" + - "\x13proxy_service.proto\x12\n" + - "management\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xae\x01\n" + - "\x11ProxyCapabilities\x127\n" + - "\x15supports_custom_ports\x18\x01 \x01(\bH\x00R\x13supportsCustomPorts\x88\x01\x01\x120\n" + - "\x11require_subdomain\x18\x02 \x01(\bH\x01R\x10requireSubdomain\x88\x01\x01B\x18\n" + - "\x16_supports_custom_portsB\x14\n" + - "\x12_require_subdomain\"\xe6\x01\n" + - "\x17GetMappingUpdateRequest\x12\x19\n" + - "\bproxy_id\x18\x01 \x01(\tR\aproxyId\x12\x18\n" + - "\aversion\x18\x02 \x01(\tR\aversion\x129\n" + - "\n" + - "started_at\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\tstartedAt\x12\x18\n" + - "\aaddress\x18\x04 \x01(\tR\aaddress\x12A\n" + - "\fcapabilities\x18\x05 \x01(\v2\x1d.management.ProxyCapabilitiesR\fcapabilities\"\x82\x01\n" + - "\x18GetMappingUpdateResponse\x122\n" + - "\amapping\x18\x01 \x03(\v2\x18.management.ProxyMappingR\amapping\x122\n" + - "\x15initial_sync_complete\x18\x02 \x01(\bR\x13initialSyncComplete\"\xce\x03\n" + - "\x11PathTargetOptions\x12&\n" + - "\x0fskip_tls_verify\x18\x01 \x01(\bR\rskipTlsVerify\x12B\n" + - "\x0frequest_timeout\x18\x02 \x01(\v2\x19.google.protobuf.DurationR\x0erequestTimeout\x12>\n" + - "\fpath_rewrite\x18\x03 \x01(\x0e2\x1b.management.PathRewriteModeR\vpathRewrite\x12W\n" + - "\x0ecustom_headers\x18\x04 \x03(\v20.management.PathTargetOptions.CustomHeadersEntryR\rcustomHeaders\x12%\n" + - "\x0eproxy_protocol\x18\x05 \x01(\bR\rproxyProtocol\x12K\n" + - "\x14session_idle_timeout\x18\x06 \x01(\v2\x19.google.protobuf.DurationR\x12sessionIdleTimeout\x1a@\n" + - "\x12CustomHeadersEntry\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"r\n" + - "\vPathMapping\x12\x12\n" + - "\x04path\x18\x01 \x01(\tR\x04path\x12\x16\n" + - "\x06target\x18\x02 \x01(\tR\x06target\x127\n" + - "\aoptions\x18\x03 \x01(\v2\x1d.management.PathTargetOptionsR\aoptions\"G\n" + - "\n" + - "HeaderAuth\x12\x16\n" + - "\x06header\x18\x01 \x01(\tR\x06header\x12!\n" + - "\fhashed_value\x18\x02 \x01(\tR\vhashedValue\"\xe5\x01\n" + - "\x0eAuthentication\x12\x1f\n" + - "\vsession_key\x18\x01 \x01(\tR\n" + - "sessionKey\x125\n" + - "\x17max_session_age_seconds\x18\x02 \x01(\x03R\x14maxSessionAgeSeconds\x12\x1a\n" + - "\bpassword\x18\x03 \x01(\bR\bpassword\x12\x10\n" + - "\x03pin\x18\x04 \x01(\bR\x03pin\x12\x12\n" + - "\x04oidc\x18\x05 \x01(\bR\x04oidc\x129\n" + - "\fheader_auths\x18\x06 \x03(\v2\x16.management.HeaderAuthR\vheaderAuths\"\xb8\x01\n" + - "\x12AccessRestrictions\x12#\n" + - "\rallowed_cidrs\x18\x01 \x03(\tR\fallowedCidrs\x12#\n" + - "\rblocked_cidrs\x18\x02 \x03(\tR\fblockedCidrs\x12+\n" + - "\x11allowed_countries\x18\x03 \x03(\tR\x10allowedCountries\x12+\n" + - "\x11blocked_countries\x18\x04 \x03(\tR\x10blockedCountries\"\xe6\x03\n" + - "\fProxyMapping\x126\n" + - "\x04type\x18\x01 \x01(\x0e2\".management.ProxyMappingUpdateTypeR\x04type\x12\x0e\n" + - "\x02id\x18\x02 \x01(\tR\x02id\x12\x1d\n" + - "\n" + - "account_id\x18\x03 \x01(\tR\taccountId\x12\x16\n" + - "\x06domain\x18\x04 \x01(\tR\x06domain\x12+\n" + - "\x04path\x18\x05 \x03(\v2\x17.management.PathMappingR\x04path\x12\x1d\n" + - "\n" + - "auth_token\x18\x06 \x01(\tR\tauthToken\x12.\n" + - "\x04auth\x18\a \x01(\v2\x1a.management.AuthenticationR\x04auth\x12(\n" + - "\x10pass_host_header\x18\b \x01(\bR\x0epassHostHeader\x12+\n" + - "\x11rewrite_redirects\x18\t \x01(\bR\x10rewriteRedirects\x12\x12\n" + - "\x04mode\x18\n" + - " \x01(\tR\x04mode\x12\x1f\n" + - "\vlisten_port\x18\v \x01(\x05R\n" + - "listenPort\x12O\n" + - "\x13access_restrictions\x18\f \x01(\v2\x1e.management.AccessRestrictionsR\x12accessRestrictions\"?\n" + - "\x14SendAccessLogRequest\x12'\n" + - "\x03log\x18\x01 \x01(\v2\x15.management.AccessLogR\x03log\"\x17\n" + - "\x15SendAccessLogResponse\"\x86\x04\n" + - "\tAccessLog\x128\n" + - "\ttimestamp\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x15\n" + - "\x06log_id\x18\x02 \x01(\tR\x05logId\x12\x1d\n" + - "\n" + - "account_id\x18\x03 \x01(\tR\taccountId\x12\x1d\n" + - "\n" + - "service_id\x18\x04 \x01(\tR\tserviceId\x12\x12\n" + - "\x04host\x18\x05 \x01(\tR\x04host\x12\x12\n" + - "\x04path\x18\x06 \x01(\tR\x04path\x12\x1f\n" + - "\vduration_ms\x18\a \x01(\x03R\n" + - "durationMs\x12\x16\n" + - "\x06method\x18\b \x01(\tR\x06method\x12#\n" + - "\rresponse_code\x18\t \x01(\x05R\fresponseCode\x12\x1b\n" + - "\tsource_ip\x18\n" + - " \x01(\tR\bsourceIp\x12%\n" + - "\x0eauth_mechanism\x18\v \x01(\tR\rauthMechanism\x12\x17\n" + - "\auser_id\x18\f \x01(\tR\x06userId\x12!\n" + - "\fauth_success\x18\r \x01(\bR\vauthSuccess\x12!\n" + - "\fbytes_upload\x18\x0e \x01(\x03R\vbytesUpload\x12%\n" + - "\x0ebytes_download\x18\x0f \x01(\x03R\rbytesDownload\x12\x1a\n" + - "\bprotocol\x18\x10 \x01(\tR\bprotocol\"\xf8\x01\n" + - "\x13AuthenticateRequest\x12\x0e\n" + - "\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n" + - "\n" + - "account_id\x18\x02 \x01(\tR\taccountId\x129\n" + - "\bpassword\x18\x03 \x01(\v2\x1b.management.PasswordRequestH\x00R\bpassword\x12*\n" + - "\x03pin\x18\x04 \x01(\v2\x16.management.PinRequestH\x00R\x03pin\x12@\n" + - "\vheader_auth\x18\x05 \x01(\v2\x1d.management.HeaderAuthRequestH\x00R\n" + - "headerAuthB\t\n" + - "\arequest\"W\n" + - "\x11HeaderAuthRequest\x12!\n" + - "\fheader_value\x18\x01 \x01(\tR\vheaderValue\x12\x1f\n" + - "\vheader_name\x18\x02 \x01(\tR\n" + - "headerName\"-\n" + - "\x0fPasswordRequest\x12\x1a\n" + - "\bpassword\x18\x01 \x01(\tR\bpassword\"\x1e\n" + - "\n" + - "PinRequest\x12\x10\n" + - "\x03pin\x18\x01 \x01(\tR\x03pin\"U\n" + - "\x14AuthenticateResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess\x12#\n" + - "\rsession_token\x18\x02 \x01(\tR\fsessionToken\"\xf3\x01\n" + - "\x17SendStatusUpdateRequest\x12\x1d\n" + - "\n" + - "service_id\x18\x01 \x01(\tR\tserviceId\x12\x1d\n" + - "\n" + - "account_id\x18\x02 \x01(\tR\taccountId\x12/\n" + - "\x06status\x18\x03 \x01(\x0e2\x17.management.ProxyStatusR\x06status\x12-\n" + - "\x12certificate_issued\x18\x04 \x01(\bR\x11certificateIssued\x12(\n" + - "\rerror_message\x18\x05 \x01(\tH\x00R\ferrorMessage\x88\x01\x01B\x10\n" + - "\x0e_error_message\"\x1a\n" + - "\x18SendStatusUpdateResponse\"\xb8\x01\n" + - "\x16CreateProxyPeerRequest\x12\x1d\n" + - "\n" + - "service_id\x18\x01 \x01(\tR\tserviceId\x12\x1d\n" + - "\n" + - "account_id\x18\x02 \x01(\tR\taccountId\x12\x14\n" + - "\x05token\x18\x03 \x01(\tR\x05token\x120\n" + - "\x14wireguard_public_key\x18\x04 \x01(\tR\x12wireguardPublicKey\x12\x18\n" + - "\acluster\x18\x05 \x01(\tR\acluster\"o\n" + - "\x17CreateProxyPeerResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess\x12(\n" + - "\rerror_message\x18\x02 \x01(\tH\x00R\ferrorMessage\x88\x01\x01B\x10\n" + - "\x0e_error_message\"e\n" + - "\x11GetOIDCURLRequest\x12\x0e\n" + - "\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n" + - "\n" + - "account_id\x18\x02 \x01(\tR\taccountId\x12!\n" + - "\fredirect_url\x18\x03 \x01(\tR\vredirectUrl\"&\n" + - "\x12GetOIDCURLResponse\x12\x10\n" + - "\x03url\x18\x01 \x01(\tR\x03url\"U\n" + - "\x16ValidateSessionRequest\x12\x16\n" + - "\x06domain\x18\x01 \x01(\tR\x06domain\x12#\n" + - "\rsession_token\x18\x02 \x01(\tR\fsessionToken\"\x8c\x01\n" + - "\x17ValidateSessionResponse\x12\x14\n" + - "\x05valid\x18\x01 \x01(\bR\x05valid\x12\x17\n" + - "\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1d\n" + - "\n" + - "user_email\x18\x03 \x01(\tR\tuserEmail\x12#\n" + - "\rdenied_reason\x18\x04 \x01(\tR\fdeniedReason*d\n" + - "\x16ProxyMappingUpdateType\x12\x17\n" + - "\x13UPDATE_TYPE_CREATED\x10\x00\x12\x18\n" + - "\x14UPDATE_TYPE_MODIFIED\x10\x01\x12\x17\n" + - "\x13UPDATE_TYPE_REMOVED\x10\x02*F\n" + - "\x0fPathRewriteMode\x12\x18\n" + - "\x14PATH_REWRITE_DEFAULT\x10\x00\x12\x19\n" + - "\x15PATH_REWRITE_PRESERVE\x10\x01*\xc8\x01\n" + - "\vProxyStatus\x12\x18\n" + - "\x14PROXY_STATUS_PENDING\x10\x00\x12\x17\n" + - "\x13PROXY_STATUS_ACTIVE\x10\x01\x12#\n" + - "\x1fPROXY_STATUS_TUNNEL_NOT_CREATED\x10\x02\x12$\n" + - " PROXY_STATUS_CERTIFICATE_PENDING\x10\x03\x12#\n" + - "\x1fPROXY_STATUS_CERTIFICATE_FAILED\x10\x04\x12\x16\n" + - "\x12PROXY_STATUS_ERROR\x10\x052\xfc\x04\n" + - "\fProxyService\x12_\n" + - "\x10GetMappingUpdate\x12#.management.GetMappingUpdateRequest\x1a$.management.GetMappingUpdateResponse0\x01\x12T\n" + - "\rSendAccessLog\x12 .management.SendAccessLogRequest\x1a!.management.SendAccessLogResponse\x12Q\n" + - "\fAuthenticate\x12\x1f.management.AuthenticateRequest\x1a .management.AuthenticateResponse\x12]\n" + - "\x10SendStatusUpdate\x12#.management.SendStatusUpdateRequest\x1a$.management.SendStatusUpdateResponse\x12Z\n" + - "\x0fCreateProxyPeer\x12\".management.CreateProxyPeerRequest\x1a#.management.CreateProxyPeerResponse\x12K\n" + - "\n" + - "GetOIDCURL\x12\x1d.management.GetOIDCURLRequest\x1a\x1e.management.GetOIDCURLResponse\x12Z\n" + - "\x0fValidateSession\x12\".management.ValidateSessionRequest\x1a#.management.ValidateSessionResponseB\bZ\x06/protob\x06proto3" +var file_proxy_service_proto_rawDesc = []byte{ + 0x0a, 0x13, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x22, 0xf6, 0x01, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, + 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x15, 0x73, 0x75, 0x70, 0x70, + 0x6f, 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x13, 0x73, 0x75, 0x70, 0x70, 0x6f, + 0x72, 0x74, 0x73, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x6f, 0x72, 0x74, 0x73, 0x88, 0x01, + 0x01, 0x12, 0x30, 0x0a, 0x11, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x73, 0x75, 0x62, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x48, 0x01, 0x52, 0x10, + 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x53, 0x75, 0x62, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x11, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x5f, + 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x48, 0x02, + 0x52, 0x10, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x43, 0x72, 0x6f, 0x77, 0x64, 0x73, + 0x65, 0x63, 0x88, 0x01, 0x01, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x42, + 0x14, 0x0a, 0x12, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x73, 0x75, 0x62, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x5f, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x22, 0xe6, 0x01, 0x0a, 0x17, + 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, + 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, + 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, + 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, + 0x74, 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, + 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, + 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xce, 0x03, 0x0a, 0x11, 0x50, 0x61, + 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, + 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, + 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, + 0x61, 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, + 0x70, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, + 0x64, 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, + 0x6f, 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, + 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, + 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, + 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, + 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, + 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x47, + 0x0a, 0x0a, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, + 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x68, 0x65, + 0x61, 0x64, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x61, 0x73, 0x68, 0x65, 0x64, 0x5f, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x61, 0x73, 0x68, + 0x65, 0x64, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xe5, 0x01, 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, + 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, + 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, + 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, + 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, + 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, + 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, + 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, + 0x6f, 0x69, 0x64, 0x63, 0x12, 0x39, 0x0a, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x61, + 0x75, 0x74, 0x68, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, + 0x74, 0x68, 0x52, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x73, 0x22, + 0xdd, 0x01, 0x0a, 0x12, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, 0x69, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, + 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x61, + 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x43, 0x69, 0x64, 0x72, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x62, + 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0c, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x43, 0x69, 0x64, 0x72, 0x73, + 0x12, 0x2b, 0x0a, 0x11, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x72, 0x69, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x61, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x2b, 0x0a, + 0x11, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, + 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, + 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x72, + 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0c, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x4d, 0x6f, 0x64, 0x65, 0x22, + 0xe6, 0x03, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x12, 0x36, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, + 0x2b, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x61, 0x75, 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, + 0x75, 0x74, 0x68, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, + 0x61, 0x73, 0x73, 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, + 0x5f, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x10, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, + 0x74, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, + 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6c, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x4f, 0x0a, 0x13, 0x61, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x5f, 0x72, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x12, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, + 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, + 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, + 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x84, 0x05, 0x0a, 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, + 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, + 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, + 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, + 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, + 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, + 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, + 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, + 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, + 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, + 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, + 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, + 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, + 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, + 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, + 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, + 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, + 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, + 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, + 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, + 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, + 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, + 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3f, 0x0a, 0x08, 0x6d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x11, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, + 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xf8, 0x01, 0x0a, 0x13, 0x41, 0x75, + 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, + 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, + 0x12, 0x39, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, + 0x00, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, + 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x48, 0x00, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x40, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x68, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0x57, 0x0a, 0x11, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, + 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x65, 0x61, + 0x64, 0x65, 0x72, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1f, 0x0a, 0x0b, + 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x2d, 0x0a, + 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, 0x0a, + 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, 0x14, + 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x23, + 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x22, 0xf3, 0x01, 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, + 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, 0x0a, + 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x2d, + 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, 0x73, + 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, 0x74, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, 0x0a, + 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, + 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, 0x6e, + 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, + 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x14, + 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, + 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, 0x62, + 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, + 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, + 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, + 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, + 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, + 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, + 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, + 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x64, + 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4f, + 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, + 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, + 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, + 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, + 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, + 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, + 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, + 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, + 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, + 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, + 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, + 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, + 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, + 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, + 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, + 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, 0x54, + 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, + 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, + 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, + 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, + 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, + 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, + 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, + 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, + 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, + 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, + 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, + 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, + 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, + 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, + 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, + 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, + 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, + 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, + 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, + 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} var ( file_proxy_service_proto_rawDescOnce sync.Once - file_proxy_service_proto_rawDescData []byte + file_proxy_service_proto_rawDescData = file_proxy_service_proto_rawDesc ) func file_proxy_service_proto_rawDescGZIP() []byte { file_proxy_service_proto_rawDescOnce.Do(func() { - file_proxy_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proxy_service_proto_rawDesc), len(file_proxy_service_proto_rawDesc))) + file_proxy_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_proxy_service_proto_rawDescData) }) return file_proxy_service_proto_rawDescData } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 26) -var file_proxy_service_proto_goTypes = []any{ +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 27) +var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (PathRewriteMode)(0), // 1: management.PathRewriteMode (ProxyStatus)(0), // 2: management.ProxyStatus @@ -2097,17 +2365,18 @@ var file_proxy_service_proto_goTypes = []any{ (*ValidateSessionRequest)(nil), // 26: management.ValidateSessionRequest (*ValidateSessionResponse)(nil), // 27: management.ValidateSessionResponse nil, // 28: management.PathTargetOptions.CustomHeadersEntry - (*timestamppb.Timestamp)(nil), // 29: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 30: google.protobuf.Duration + nil, // 29: management.AccessLog.MetadataEntry + (*timestamppb.Timestamp)(nil), // 30: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 31: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 29, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 30, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities 11, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 30, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 31, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode 28, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 30, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 31, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions 8, // 8: management.Authentication.header_auths:type_name -> management.HeaderAuth 0, // 9: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType @@ -2115,30 +2384,31 @@ var file_proxy_service_proto_depIdxs = []int32{ 9, // 11: management.ProxyMapping.auth:type_name -> management.Authentication 10, // 12: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions 14, // 13: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 29, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 17, // 15: management.AuthenticateRequest.password:type_name -> management.PasswordRequest - 18, // 16: management.AuthenticateRequest.pin:type_name -> management.PinRequest - 16, // 17: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest - 2, // 18: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 4, // 19: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 12, // 20: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 15, // 21: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 20, // 22: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 22, // 23: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 24, // 24: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 26, // 25: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 5, // 26: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 13, // 27: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 19, // 28: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 21, // 29: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 23, // 30: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 25, // 31: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 27, // 32: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 26, // [26:33] is the sub-list for method output_type - 19, // [19:26] is the sub-list for method input_type - 19, // [19:19] is the sub-list for extension type_name - 19, // [19:19] is the sub-list for extension extendee - 0, // [0:19] is the sub-list for field type_name + 30, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 29, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry + 17, // 16: management.AuthenticateRequest.password:type_name -> management.PasswordRequest + 18, // 17: management.AuthenticateRequest.pin:type_name -> management.PinRequest + 16, // 18: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest + 2, // 19: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus + 4, // 20: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 12, // 21: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 15, // 22: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 20, // 23: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 22, // 24: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 24, // 25: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 26, // 26: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 5, // 27: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 13, // 28: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 19, // 29: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 21, // 30: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 23, // 31: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 25, // 32: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 27, // 33: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 27, // [27:34] is the sub-list for method output_type + 20, // [20:27] is the sub-list for method input_type + 20, // [20:20] is the sub-list for extension type_name + 20, // [20:20] is the sub-list for extension extendee + 0, // [0:20] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -2146,21 +2416,323 @@ func file_proxy_service_proto_init() { if File_proxy_service_proto != nil { return } - file_proxy_service_proto_msgTypes[0].OneofWrappers = []any{} - file_proxy_service_proto_msgTypes[12].OneofWrappers = []any{ + if !protoimpl.UnsafeEnabled { + file_proxy_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ProxyCapabilities); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetMappingUpdateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetMappingUpdateResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PathTargetOptions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PathMapping); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HeaderAuth); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Authentication); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AccessRestrictions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ProxyMapping); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SendAccessLogRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SendAccessLogResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AccessLog); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AuthenticateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HeaderAuthRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PasswordRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PinRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AuthenticateResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SendStatusUpdateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SendStatusUpdateResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateProxyPeerRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateProxyPeerResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetOIDCURLRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetOIDCURLResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValidateSessionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValidateSessionResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[12].OneofWrappers = []interface{}{ (*AuthenticateRequest_Password)(nil), (*AuthenticateRequest_Pin)(nil), (*AuthenticateRequest_HeaderAuth)(nil), } - file_proxy_service_proto_msgTypes[17].OneofWrappers = []any{} - file_proxy_service_proto_msgTypes[20].OneofWrappers = []any{} + file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[20].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_service_proto_rawDesc), len(file_proxy_service_proto_rawDesc)), + RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 3, - NumMessages: 26, + NumMessages: 27, NumExtensions: 0, NumServices: 1, }, @@ -2170,6 +2742,7 @@ func file_proxy_service_proto_init() { MessageInfos: file_proxy_service_proto_msgTypes, }.Build() File_proxy_service_proto = out.File + file_proxy_service_proto_rawDesc = nil file_proxy_service_proto_goTypes = nil file_proxy_service_proto_depIdxs = nil } diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index f77071eb0..e359f0cbd 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -34,6 +34,8 @@ message ProxyCapabilities { // Whether the proxy requires a subdomain label in front of its cluster domain. // When true, accounts cannot use the cluster domain bare. optional bool require_subdomain = 2; + // Whether the proxy has CrowdSec configured and can enforce IP reputation checks. + optional bool supports_crowdsec = 3; } // GetMappingUpdateRequest is sent to initialise a mapping stream. @@ -104,6 +106,8 @@ message AccessRestrictions { repeated string blocked_cidrs = 2; repeated string allowed_countries = 3; repeated string blocked_countries = 4; + // CrowdSec IP reputation mode: "", "off", "enforce", or "observe". + string crowdsec_mode = 5; } message ProxyMapping { @@ -152,6 +156,8 @@ message AccessLog { int64 bytes_upload = 14; int64 bytes_download = 15; string protocol = 16; + // Extra key-value metadata for the access log entry (e.g. crowdsec_verdict, scenario). + map metadata = 17; } message AuthenticateRequest { diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index ed1b63435..1800bddb2 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -2,8 +2,12 @@ package client import ( "context" + "errors" "fmt" "net" + "net/netip" + "net/url" + "strings" "sync" "time" @@ -146,6 +150,7 @@ func (cc *connContainer) close() { type Client struct { log *log.Entry connectionURL string + serverIP netip.Addr authTokenStore *auth.TokenStore hashedID messages.PeerID @@ -170,13 +175,22 @@ type Client struct { } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect +// is called. func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { + return NewClientWithServerIP(serverURL, netip.Addr{}, authTokenStore, peerID, mtu) +} + +// NewClientWithServerIP creates a new client for the relay server with a known server IP. serverIP, when valid, is +// dialed directly first; the FQDN is only attempted if the IP-based dial fails. TLS verification still uses the +// FQDN from serverURL via SNI. +func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { hashedID := messages.HashID(peerID) relayLog := log.WithFields(log.Fields{"relay": serverURL}) c := &Client{ log: relayLog, connectionURL: serverURL, + serverIP: serverIP, authTokenStore: authTokenStore, hashedID: hashedID, mtu: mtu, @@ -304,6 +318,23 @@ func (c *Client) ServerInstanceURL() (string, error) { return c.instanceURL.String(), nil } +// ConnectedIP returns the IP address of the live relay-server connection, +// extracted from the underlying socket's RemoteAddr. Zero value if not +// connected or if the address is not an IP literal. +func (c *Client) ConnectedIP() netip.Addr { + c.mu.Lock() + conn := c.relayConn + c.mu.Unlock() + if conn == nil { + return netip.Addr{} + } + addr := conn.RemoteAddr() + if addr == nil { + return netip.Addr{} + } + return extractIPLiteral(addr.String()) +} + // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. func (c *Client) SetOnDisconnectListener(fn func(string)) { c.listenerMutex.Lock() @@ -332,10 +363,23 @@ func (c *Client) Close() error { func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { dialers := c.getDialers() - rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) - conn, err := rd.Dial() - if err != nil { - return nil, err + var conn net.Conn + if c.serverIP.IsValid() { + var err error + conn, err = c.dialRaceDirect(ctx, dialers) + if err != nil { + c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err) + conn = nil + } + } + + if conn == nil { + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) + var err error + conn, err = rd.Dial(ctx) + if err != nil { + return nil, fmt.Errorf("dial via FQDN: %w", err) + } } c.relayConn = conn @@ -351,6 +395,52 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { return instanceURL, nil } +// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI. +func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) { + directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP) + if err != nil { + return nil, fmt.Errorf("substitute host: %w", err) + } + + c.log.Debugf("dialing via server IP %s (SNI=%s)", c.serverIP, serverName) + + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...). + WithServerName(serverName) + return rd.Dial(ctx) +} + +// substituteHost replaces the host portion of a rel/rels URL with ip, +// preserving the scheme and port. Returns the rewritten URL and the +// original host to use as the TLS ServerName, or empty if the original +// host is itself an IP literal (SNI requires a DNS name). +func substituteHost(serverURL string, ip netip.Addr) (string, string, error) { + u, err := url.Parse(serverURL) + if err != nil { + return "", "", fmt.Errorf("parse %q: %w", serverURL, err) + } + if u.Scheme == "" || u.Host == "" { + return "", "", fmt.Errorf("invalid relay URL %q", serverURL) + } + if !ip.IsValid() { + return "", "", errors.New("invalid server IP") + } + origHost := u.Hostname() + if _, err := netip.ParseAddr(origHost); err == nil { + origHost = "" + } + ip = ip.Unmap() + newHost := ip.String() + if ip.Is6() { + newHost = "[" + newHost + "]" + } + if port := u.Port(); port != "" { + u.Host = newHost + ":" + port + } else { + u.Host = newHost + } + return u.String(), origHost, nil +} + func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { @@ -716,3 +806,21 @@ func (c *Client) handlePeersWentOfflineMsg(buf []byte) { } c.stateSubscription.OnPeersWentOffline(peersID) } + +// extractIPLiteral returns the IP from address forms produced by the relay +// dialers (URL or host:port). Zero value if the host is not an IP. +func extractIPLiteral(s string) netip.Addr { + if u, err := url.Parse(s); err == nil && u.Host != "" { + s = u.Host + } + host, _, err := net.SplitHostPort(s) + if err != nil { + host = s + } + host = strings.Trim(host, "[]") + ip, err := netip.ParseAddr(host) + if err != nil { + return netip.Addr{} + } + return ip.Unmap() +} diff --git a/shared/relay/client/client_serverip_test.go b/shared/relay/client/client_serverip_test.go new file mode 100644 index 000000000..7e699e37d --- /dev/null +++ b/shared/relay/client/client_serverip_test.go @@ -0,0 +1,280 @@ +package client + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" + + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/shared/relay/auth/allow" +) + +// TestClient_ServerIPRecoversFromUnresolvableFQDN verifies that when the +// primary FQDN-based dial fails (unresolvable .invalid host), Connect +// recovers via the server IP and SNI still uses the FQDN. +func TestClient_ServerIPRecoversFromUnresolvableFQDN(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + listenAddr, port := freeAddr(t) + srvCfg := server.Config{ + Meter: otel.Meter(""), + ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port), + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(srvCfg) + if err != nil { + t.Fatalf("create server: %s", err) + } + + errChan := make(chan error, 1) + go func() { + if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil { + errChan <- err + } + }() + t.Cleanup(func() { + if err := srv.Shutdown(context.Background()); err != nil { + t.Errorf("shutdown server: %s", err) + } + }) + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("server failed to start: %s", err) + } + + t.Run("no server IP, primary fails", func(t *testing.T) { + c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-noip", iface.DefaultMTU) + err := c.Connect(ctx) + if err == nil { + _ = c.Close() + t.Fatalf("expected connect to fail without server IP, got nil") + } + }) + + t.Run("server IP recovers", func(t *testing.T) { + c := NewClientWithServerIP(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-with-ip", iface.DefaultMTU) + if err := c.Connect(ctx); err != nil { + t.Fatalf("connect with server IP: %s", err) + } + t.Cleanup(func() { _ = c.Close() }) + + if !c.Ready() { + t.Fatalf("client not ready after connect") + } + if got := c.ConnectedIP(); got.String() != "127.0.0.1" { + t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got) + } + }) +} + +// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the +// resolved IP after a successful FQDN-based dial. The underlying socket's +// RemoteAddr must be exposed through the dialer wrappers; if it returns +// the dial-time URL instead, ConnectedIP returns empty and the dial +// IP we advertise to peers is empty too. +func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + listenAddr, port := freeAddr(t) + srvCfg := server.Config{ + Meter: otel.Meter(""), + ExposedAddress: fmt.Sprintf("rel://localhost:%d", port), + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(srvCfg) + if err != nil { + t.Fatalf("create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil { + errChan <- err + } + }() + t.Cleanup(func() { _ = srv.Shutdown(context.Background()) }) + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("server failed to start: %s", err) + } + + c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-fqdn", iface.DefaultMTU) + if err := c.Connect(ctx); err != nil { + t.Fatalf("connect: %s", err) + } + t.Cleanup(func() { _ = c.Close() }) + + got := c.ConnectedIP().String() + if got != "127.0.0.1" && got != "::1" { + t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got) + } +} + +func TestSubstituteHost(t *testing.T) { + tests := []struct { + name string + serverURL string + ip string + wantURL string + wantServerName string + wantErr bool + }{ + { + name: "rels with port", + serverURL: "rels://relay.netbird.io:443", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5:443", + wantServerName: "relay.netbird.io", + }, + { + name: "rel with port", + serverURL: "rel://relay.example.com:80", + ip: "192.0.2.1", + wantURL: "rel://192.0.2.1:80", + wantServerName: "relay.example.com", + }, + { + name: "ipv6 server IP bracketed", + serverURL: "rels://relay.example.com:443", + ip: "2001:db8::1", + wantURL: "rels://[2001:db8::1]:443", + wantServerName: "relay.example.com", + }, + { + name: "no port", + serverURL: "rels://relay.example.com", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5", + wantServerName: "relay.example.com", + }, + { + name: "ipv6 server with port returns empty SNI", + serverURL: "rels://[2001:db8::5]:443", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5:443", + wantServerName: "", + }, + { + name: "ipv4 server with port returns empty SNI", + serverURL: "rels://10.0.0.5:443", + ip: "10.0.0.6", + wantURL: "rels://10.0.0.6:443", + wantServerName: "", + }, + { + name: "ipv6 server IP no port", + serverURL: "rels://relay.example.com", + ip: "2001:db8::1", + wantURL: "rels://[2001:db8::1]", + wantServerName: "relay.example.com", + }, + { + name: "missing scheme", + serverURL: "relay.example.com:443", + ip: "10.0.0.5", + wantErr: true, + }, + { + name: "empty", + serverURL: "", + ip: "10.0.0.5", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ip netip.Addr + if tt.ip != "" { + ip = netip.MustParseAddr(tt.ip) + } + gotURL, gotName, err := substituteHost(tt.serverURL, ip) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if gotURL != tt.wantURL { + t.Errorf("URL = %q, want %q", gotURL, tt.wantURL) + } + if gotName != tt.wantServerName { + t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName) + } + }) + } +} + +func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) { + c := NewClient("rel://example.invalid:80", hmacTokenStore, "x", iface.DefaultMTU) + if got := c.ConnectedIP(); got.IsValid() { + t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got) + } +} + +// staticAddr is a net.Addr that returns a fixed string. Used to verify +// ConnectedIP parses RemoteAddr correctly. +type staticAddr struct{ s string } + +func (a staticAddr) Network() string { return "tcp" } +func (a staticAddr) String() string { return a.s } + +type stubConn struct { + net.Conn + remote net.Addr +} + +func (s stubConn) RemoteAddr() net.Addr { return s.remote } + +func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) { + tests := []struct { + name string + s string + want string + }{ + {"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"}, + {"hostport ipv6 bracketed", "[::1]:50301", "::1"}, + {"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"}, + {"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"}, + {"fqdn url returns empty", "rel://relay.example.com:50301", ""}, + {"fqdn hostport returns empty", "relay.example.com:50301", ""}, + {"plain ipv4 no port", "10.0.0.1", "10.0.0.1"}, + {"empty", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}} + got := c.ConnectedIP() + var gotStr string + if got.IsValid() { + gotStr = got.String() + } + if gotStr != tt.want { + t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want) + } + }) + } +} + +// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The +// listener is closed before returning, so the port is briefly free for +// the caller to bind. Avoids hardcoded ports that can collide. +func freeAddr(t *testing.T) (string, int) { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("get free port: %s", err) + } + addr := l.Addr().(*net.TCPAddr) + _ = l.Close() + return addr.String(), addr.Port +} diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 78462837d..602803b19 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -23,7 +23,7 @@ func (d Dialer) Protocol() string { return Network } -func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) { quicURL, err := prepareURL(address) if err != nil { return nil, err @@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { // Get the base TLS config tlsClientConfig := quictls.ClientQUICTLSConfig() - // Set ServerName to hostname if not an IP address - host, _, splitErr := net.SplitHostPort(quicURL) - if splitErr == nil && net.ParseIP(host) == nil { - // It's a hostname, not an IP - modify directly - tlsClientConfig.ServerName = host + switch { + case serverName != "" && net.ParseIP(serverName) == nil: + tlsClientConfig.ServerName = serverName + default: + host, _, splitErr := net.SplitHostPort(quicURL) + if splitErr == nil && net.ParseIP(host) == nil { + tlsClientConfig.ServerName = host + } } quicConfig := &quic.Config{ @@ -89,12 +92,12 @@ func prepareURL(address string) (string, error) { finalHost, finalPort, err := net.SplitHostPort(host) if err != nil { if strings.Contains(err.Error(), "missing port") { - return host + ":" + defaultPort, nil + return net.JoinHostPort(strings.Trim(host, "[]"), defaultPort), nil } // return any other split error as is return "", err } - return finalHost + ":" + finalPort, nil + return net.JoinHostPort(finalHost, finalPort), nil } diff --git a/shared/relay/client/dialer/race_dialer.go b/shared/relay/client/dialer/race_dialer.go index 0550fc63e..15208b858 100644 --- a/shared/relay/client/dialer/race_dialer.go +++ b/shared/relay/client/dialer/race_dialer.go @@ -14,7 +14,9 @@ const ( ) type DialeFn interface { - Dial(ctx context.Context, address string) (net.Conn, error) + // Dial connects to address. serverName, when non-empty, overrides the TLS + // ServerName used for SNI/cert validation. Empty means derive from address. + Dial(ctx context.Context, address, serverName string) (net.Conn, error) Protocol() string } @@ -27,6 +29,7 @@ type dialResult struct { type RaceDial struct { log *log.Entry serverURL string + serverName string dialerFns []DialeFn connectionTimeout time.Duration } @@ -40,10 +43,20 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri } } -func (r *RaceDial) Dial() (net.Conn, error) { +// WithServerName sets a TLS SNI/cert validation override. Used when serverURL +// contains an IP literal but the cert is issued for a different hostname. +// +// Mutates the receiver and is not safe for concurrent reconfiguration; a +// RaceDial is intended to be constructed per dial and discarded. +func (r *RaceDial) WithServerName(serverName string) *RaceDial { + r.serverName = serverName + return r +} + +func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) { connChan := make(chan dialResult, len(r.dialerFns)) winnerConn := make(chan net.Conn, 1) - abortCtx, abort := context.WithCancel(context.Background()) + abortCtx, abort := context.WithCancel(ctx) defer abort() for _, dfn := range r.dialerFns { @@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia defer cancel() r.log.Infof("dialing Relay server via %s", dfn.Protocol()) - conn, err := dfn.Dial(ctx, r.serverURL) + conn, err := dfn.Dial(ctx, r.serverURL, r.serverName) connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err} } diff --git a/shared/relay/client/dialer/race_dialer_test.go b/shared/relay/client/dialer/race_dialer_test.go index d216ec5e7..a53edc00e 100644 --- a/shared/relay/client/dialer/race_dialer_test.go +++ b/shared/relay/client/dialer/race_dialer_test.go @@ -28,7 +28,7 @@ type MockDialer struct { protocolStr string } -func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (m *MockDialer) Dial(ctx context.Context, address, _ string) (net.Conn, error) { return m.dialFunc(ctx, address) } @@ -78,7 +78,7 @@ func TestRaceDialEmptyDialers(t *testing.T) { serverURL := "test.server.com" rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error with empty dialers, got nil") } @@ -104,7 +104,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -137,7 +137,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -160,7 +160,7 @@ func TestRaceDialTimeout(t *testing.T) { } rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -188,7 +188,7 @@ func TestRaceDialAllDialersFail(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -230,7 +230,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } diff --git a/shared/relay/client/dialer/ws/conn.go b/shared/relay/client/dialer/ws/conn.go index d5b719f51..9497fab89 100644 --- a/shared/relay/client/dialer/ws/conn.go +++ b/shared/relay/client/dialer/ws/conn.go @@ -12,14 +12,24 @@ import ( type Conn struct { ctx context.Context *websocket.Conn - remoteAddr WebsocketAddr + remoteAddr net.Addr } -func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn { +// NewConn builds a relay ws.Conn. underlying is the raw TCP/TLS conn captured +// from the http transport's DialContext; when set, RemoteAddr returns its +// peer address (an IP literal). When nil (e.g. wasm), RemoteAddr falls back +// to the dial-time URL. +func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn) net.Conn { + var addr net.Addr = WebsocketAddr{serverAddress} + if underlying != nil { + if ra := underlying.RemoteAddr(); ra != nil { + addr = ra + } + } return &Conn{ ctx: context.Background(), Conn: wsConn, - remoteAddr: WebsocketAddr{serverAddress}, + remoteAddr: addr, } } diff --git a/shared/relay/client/dialer/ws/dialopts_generic.go b/shared/relay/client/dialer/ws/dialopts_generic.go index 9dfe698d0..8008d89d3 100644 --- a/shared/relay/client/dialer/ws/dialopts_generic.go +++ b/shared/relay/client/dialer/ws/dialopts_generic.go @@ -2,10 +2,14 @@ package ws -import "github.com/coder/websocket" +import ( + "net" -func createDialOptions() *websocket.DialOptions { + "github.com/coder/websocket" +) + +func createDialOptions(serverName string, underlyingOut *net.Conn) *websocket.DialOptions { return &websocket.DialOptions{ - HTTPClient: httpClientNbDialer(), + HTTPClient: httpClientNbDialer(serverName, underlyingOut), } } diff --git a/shared/relay/client/dialer/ws/dialopts_js.go b/shared/relay/client/dialer/ws/dialopts_js.go index 7eac27531..5b11fe765 100644 --- a/shared/relay/client/dialer/ws/dialopts_js.go +++ b/shared/relay/client/dialer/ws/dialopts_js.go @@ -2,9 +2,13 @@ package ws -import "github.com/coder/websocket" +import ( + "net" -func createDialOptions() *websocket.DialOptions { - // WASM version doesn't support HTTPClient + "github.com/coder/websocket" +) + +func createDialOptions(_ string, _ *net.Conn) *websocket.DialOptions { + // WASM version doesn't support HTTPClient or custom TLS config. return &websocket.DialOptions{} } diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 37b189e05..301486514 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -26,13 +26,14 @@ func (d Dialer) Protocol() string { return "WS" } -func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) { wsURL, err := prepareURL(address) if err != nil { return nil, err } - opts := createDialOptions() + var underlying net.Conn + opts := createDialOptions(serverName, &underlying) parsedURL, err := url.Parse(wsURL) if err != nil { @@ -52,7 +53,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { _ = resp.Body.Close() } - conn := NewConn(wsConn, address) + conn := NewConn(wsConn, address, underlying) return conn, nil } @@ -64,7 +65,10 @@ func prepareURL(address string) (string, error) { return strings.Replace(address, "rel", "ws", 1), nil } -func httpClientNbDialer() *http.Client { +// httpClientNbDialer builds the http client used by the websocket library. +// underlyingOut, when non-nil, is populated with the raw conn from the +// transport's DialContext so the caller can read its RemoteAddr. +func httpClientNbDialer(serverName string, underlyingOut *net.Conn) *http.Client { customDialer := nbnet.NewDialer() certPool, err := x509.SystemCertPool() @@ -75,10 +79,15 @@ func httpClientNbDialer() *http.Client { customTransport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return customDialer.DialContext(ctx, network, addr) + c, err := customDialer.DialContext(ctx, network, addr) + if err == nil && underlyingOut != nil { + *underlyingOut = c + } + return c, err }, TLSClientConfig: &tls.Config{ - RootCAs: certPool, + RootCAs: certPool, + ServerName: serverName, }, } diff --git a/shared/relay/client/guard.go b/shared/relay/client/guard.go index f4d3a8cce..d7892d0ce 100644 --- a/shared/relay/client/guard.go +++ b/shared/relay/client/guard.go @@ -8,10 +8,7 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - // TODO: make it configurable, the manager should validate all configurable parameters - reconnectingTimeout = 60 * time.Second -) +const defaultMaxBackoffInterval = 60 * time.Second // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { @@ -19,14 +16,23 @@ type Guard struct { OnNewRelayClient chan *Client OnReconnected chan struct{} serverPicker *ServerPicker + + // maxBackoffInterval caps the exponential backoff between reconnect + // attempts. + maxBackoffInterval time.Duration } -// NewGuard creates a new guard for the relay client. -func NewGuard(sp *ServerPicker) *Guard { +// NewGuard creates a new guard for the relay client. A non-positive +// maxBackoffInterval falls back to defaultMaxBackoffInterval. +func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard { + if maxBackoffInterval <= 0 { + maxBackoffInterval = defaultMaxBackoffInterval + } g := &Guard{ - OnNewRelayClient: make(chan *Client, 1), - OnReconnected: make(chan struct{}, 1), - serverPicker: sp, + OnNewRelayClient: make(chan *Client, 1), + OnReconnected: make(chan struct{}, 1), + serverPicker: sp, + maxBackoffInterval: maxBackoffInterval, } return g } @@ -49,7 +55,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { } // start a ticker to pick a new server - ticker := exponentTicker(ctx) + ticker := g.exponentTicker(ctx) defer ticker.Stop() for { @@ -125,11 +131,11 @@ func (g *Guard) notifyReconnected() { } } -func exponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) exponentTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 2 * time.Second, Multiplier: 2, - MaxInterval: reconnectingTimeout, + MaxInterval: g.maxBackoffInterval, Clock: backoff.SystemClock, }, ctx) diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index 6220e7f6b..3858b3c83 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "net/netip" "reflect" "sync" "time" @@ -39,6 +40,15 @@ func NewRelayTrack() *RelayTrack { type OnServerCloseListener func() +// ManagerOption configures a Manager at construction time. +type ManagerOption func(*Manager) + +// WithMaxBackoffInterval caps the exponential backoff between reconnect +// attempts to the home relay. A non-positive value keeps the default. +func WithMaxBackoffInterval(d time.Duration) ManagerOption { + return func(m *Manager) { m.maxBackoffInterval = d } +} + // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL // and automatically reconnect to them in case disconnection. // The manager also manage temporary relay connection. If a client wants to communicate with a client on a @@ -64,12 +74,16 @@ type Manager struct { onReconnectedListenerFn func() listenerLock sync.Mutex - mtu uint16 + mtu uint16 + maxBackoffInterval time.Duration + + cleanupInterval time.Duration + keepUnusedServerTime time.Duration } // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. -func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager { +func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager { tokenStore := &relayAuth.TokenStore{} m := &Manager{ @@ -85,9 +99,14 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), + cleanupInterval: relayCleanupInterval, + keepUnusedServerTime: keepUnusedServerTime, + } + for _, opt := range opts { + opt(m) } m.serverPicker.ServerURLs.Store(serverURLs) - m.reconnectGuard = NewGuard(m.serverPicker) + m.reconnectGuard = NewGuard(m.serverPicker, m.maxBackoffInterval) return m } @@ -117,7 +136,10 @@ func (m *Manager) Serve() error { // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. -func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { +// +// serverIP, when valid and serverAddress is foreign, is used as a dial target if the FQDN-based dial fails. +// Ignored for the local home-server path. TLS verification still uses the FQDN via SNI. +func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) { m.relayClientMu.RLock() defer m.relayClientMu.RUnlock() @@ -138,7 +160,7 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) ( netConn, err = m.relayClient.OpenConn(ctx, peerKey) } else { log.Debugf("open peer connection via foreign server: %s", serverAddress) - netConn, err = m.openConnVia(ctx, serverAddress, peerKey) + netConn, err = m.openConnVia(ctx, serverAddress, peerKey, serverIP) } if err != nil { return nil, err @@ -190,16 +212,22 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ return nil } -// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is -// lost. This address will be sent to the target peer to choose the common relay server for the communication. -func (m *Manager) RelayInstanceAddress() (string, error) { +// RelayInstanceAddress returns the address and resolved IP of the permanent relay server. It could change if the +// network connection is lost. The address is sent to the target peer to choose the common relay server for the +// communication; the IP is sent alongside so remote peers can dial directly without their own DNS lookup. Both +// values are read under the same lock so they cannot diverge across a reconnection. +func (m *Manager) RelayInstanceAddress() (string, netip.Addr, error) { m.relayClientMu.RLock() defer m.relayClientMu.RUnlock() if m.relayClient == nil { - return "", ErrRelayClientNotConnected + return "", netip.Addr{}, ErrRelayClientNotConnected } - return m.relayClient.ServerInstanceURL() + addr, err := m.relayClient.ServerInstanceURL() + if err != nil { + return "", netip.Addr{}, err + } + return addr, m.relayClient.ConnectedIP(), nil } // ServerURLs returns the addresses of the relay servers. @@ -223,7 +251,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error { return m.tokenStore.UpdateToken(token) } -func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { +func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) { // check if already has a connection to the desired relay server m.relayClientsMutex.RLock() rt, ok := m.relayClients[serverAddress] @@ -258,7 +286,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string m.relayClients[serverAddress] = rt m.relayClientsMutex.Unlock() - relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu) + relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu) err := relayClient.Connect(m.ctx) if err != nil { rt.err = err @@ -290,19 +318,36 @@ func (m *Manager) onServerConnected() { go m.onReconnectedListenerFn() } -// onServerDisconnected start to reconnection for home server only +// onServerDisconnected handles relay disconnect events. For the home server it +// starts the reconnect guard. For foreign servers it evicts the now-dead client +// from the cache so the next OpenConn builds a fresh one instead of reusing a +// closed client. func (m *Manager) onServerDisconnected(serverAddress string) { m.relayClientMu.Lock() - if serverAddress == m.relayClient.connectionURL { + isHome := m.relayClient != nil && serverAddress == m.relayClient.connectionURL + if isHome { go func(client *Client) { m.reconnectGuard.StartReconnectTrys(m.ctx, client) }(m.relayClient) } m.relayClientMu.Unlock() + if !isHome { + m.evictForeignRelay(serverAddress) + } + m.notifyOnDisconnectListeners(serverAddress) } +func (m *Manager) evictForeignRelay(serverAddress string) { + m.relayClientsMutex.Lock() + defer m.relayClientsMutex.Unlock() + if _, ok := m.relayClients[serverAddress]; ok { + delete(m.relayClients, serverAddress) + log.Debugf("evicted disconnected foreign relay client: %s", serverAddress) + } +} + func (m *Manager) listenGuardEvent(ctx context.Context) { for { select { @@ -334,7 +379,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - ticker := time.NewTicker(relayCleanupInterval) + ticker := time.NewTicker(m.cleanupInterval) defer ticker.Stop() for { select { @@ -359,7 +404,7 @@ func (m *Manager) cleanUpUnusedRelays() { continue } - if time.Since(rt.created) <= keepUnusedServerTime { + if time.Since(rt.created) <= m.keepUnusedServerTime { rt.Unlock() continue } diff --git a/shared/relay/client/manager_serverip_test.go b/shared/relay/client/manager_serverip_test.go new file mode 100644 index 000000000..a354beade --- /dev/null +++ b/shared/relay/client/manager_serverip_test.go @@ -0,0 +1,144 @@ +package client + +import ( + "context" + "io" + "net/netip" + "testing" + "time" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" +) + +// TestManager_ForeignRelayServerIP exercises the foreign-relay path +// end-to-end through Manager.OpenConn. Alice and Bob register on different +// relay servers; Alice dials Bob's foreign relay using an unresolvable +// FQDN. Without a server IP the dial fails; with Bob's advertised IP it +// recovers and a payload round-trips between the peers. +func TestManager_ForeignRelayServerIP(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + // Alice's home relay + homeCfg := server.ListenerConfig{Address: "127.0.0.1:52401"} + homeSrv, err := server.NewServer(newManagerTestServerConfig(homeCfg.Address)) + if err != nil { + t.Fatalf("create home server: %s", err) + } + homeErr := make(chan error, 1) + go func() { + if err := homeSrv.Listen(homeCfg); err != nil { + homeErr <- err + } + }() + t.Cleanup(func() { _ = homeSrv.Shutdown(context.Background()) }) + if err := waitForServerToStart(homeErr); err != nil { + t.Fatalf("home server: %s", err) + } + + // Bob's foreign relay + foreignCfg := server.ListenerConfig{Address: "127.0.0.1:52402"} + foreignSrv, err := server.NewServer(newManagerTestServerConfig(foreignCfg.Address)) + if err != nil { + t.Fatalf("create foreign server: %s", err) + } + foreignErr := make(chan error, 1) + go func() { + if err := foreignSrv.Listen(foreignCfg); err != nil { + foreignErr <- err + } + }() + t.Cleanup(func() { _ = foreignSrv.Shutdown(context.Background()) }) + if err := waitForServerToStart(foreignErr); err != nil { + t.Fatalf("foreign server: %s", err) + } + + mCtx, mCancel := context.WithCancel(ctx) + t.Cleanup(mCancel) + + mgrAlice := NewManager(mCtx, toURL(homeCfg), "alice", iface.DefaultMTU) + if err := mgrAlice.Serve(); err != nil { + t.Fatalf("alice manager serve: %s", err) + } + + mgrBob := NewManager(mCtx, toURL(foreignCfg), "bob", iface.DefaultMTU) + if err := mgrBob.Serve(); err != nil { + t.Fatalf("bob manager serve: %s", err) + } + + // Bob's real relay URL and the IP that would ride along in signal as relayServerIP. + bobRealAddr, bobAdvertisedIP, err := mgrBob.RelayInstanceAddress() + if err != nil { + t.Fatalf("bob relay address: %s", err) + } + if !bobAdvertisedIP.IsValid() { + t.Fatalf("expected valid RelayInstanceIP for bob, got zero") + } + + // .invalid is reserved (RFC 2606), so DNS resolution always fails. + const brokenFQDN = "rel://relay-bob-instance.invalid:52402" + if brokenFQDN == bobRealAddr { + t.Fatalf("broken FQDN must differ from bob's real address (%s)", bobRealAddr) + } + + t.Run("no server IP, dial fails", func(t *testing.T) { + dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second) + defer dialCancel() + _, err := mgrAlice.OpenConn(dialCtx, brokenFQDN, "bob", netip.Addr{}) + if err == nil { + t.Fatalf("expected OpenConn to fail without server IP, got success") + } + }) + + t.Run("server IP recovers", func(t *testing.T) { + // Bob waits for Alice's incoming peer connection on his side. + bobSideCh := make(chan error, 1) + go func() { + conn, err := mgrBob.OpenConn(ctx, bobRealAddr, "alice", netip.Addr{}) + if err != nil { + bobSideCh <- err + return + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + bobSideCh <- err + return + } + if _, err := conn.Write(buf[:n]); err != nil { + bobSideCh <- err + return + } + bobSideCh <- nil + }() + + aliceConn, err := mgrAlice.OpenConn(ctx, brokenFQDN, "bob", bobAdvertisedIP) + if err != nil { + t.Fatalf("alice OpenConn with server IP: %s", err) + } + t.Cleanup(func() { _ = aliceConn.Close() }) + + payload := []byte("alice-to-bob") + if _, err := aliceConn.Write(payload); err != nil { + t.Fatalf("alice write: %s", err) + } + + buf := make([]byte, len(payload)) + if _, err := io.ReadFull(aliceConn, buf); err != nil { + t.Fatalf("alice read echo: %s", err) + } + if string(buf) != string(payload) { + t.Fatalf("echo mismatch: got %q want %q", buf, payload) + } + + select { + case err := <-bobSideCh: + if err != nil { + t.Fatalf("bob side: %s", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for bob side") + } + }) +} diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index fb91f7682..9e964f688 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -2,6 +2,8 @@ package client import ( "context" + "fmt" + "net/netip" "testing" "time" @@ -100,15 +102,15 @@ func TestForeignConn(t *testing.T) { if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - bobsSrvAddr, err := clientBob.RelayInstanceAddress() + bobsSrvAddr, _, err := clientBob.RelayInstanceAddress() if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob") + connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice") + connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -208,7 +210,7 @@ func TestForeginConnClose(t *testing.T) { if err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob") + conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -300,7 +302,7 @@ func TestForeignAutoClose(t *testing.T) { } t.Log("open connection to another peer") - if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil { + if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer", netip.Addr{}); err == nil { t.Fatalf("should have failed to open connection to another peer") } @@ -360,16 +362,17 @@ func TestAutoReconnect(t *testing.T) { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU, + WithMaxBackoffInterval(2*time.Second)) err = clientAlice.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) } - ra, err := clientAlice.RelayInstanceAddress() + ra, _, err := clientAlice.RelayInstanceAddress() if err != nil { t.Errorf("failed to get relay address: %s", err) } - conn, err := clientAlice.OpenConn(ctx, ra, "bob") + conn, err := clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{}) if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -384,15 +387,32 @@ func TestAutoReconnect(t *testing.T) { } log.Infof("waiting for reconnection") - time.Sleep(reconnectingTimeout + 1*time.Second) + if err := waitForReady(ctx, clientAlice, 15*time.Second); err != nil { + t.Fatalf("manager did not reconnect: %s", err) + } log.Infof("reopent the connection") - _, err = clientAlice.OpenConn(ctx, ra, "bob") + _, err = clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{}) if err != nil { t.Errorf("failed to open channel: %s", err) } } +func waitForReady(ctx context.Context, m *Manager, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if m.Ready() { + return nil + } + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } + return fmt.Errorf("manager not ready within %s", timeout) +} + func TestNotifierDoubleAdd(t *testing.T) { ctx := context.Background() @@ -434,7 +454,7 @@ func TestNotifierDoubleAdd(t *testing.T) { t.Fatalf("failed to serve manager: %s", err) } - conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob") + conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/shared/signal/client/client.go b/shared/signal/client/client.go index 5347c80e9..9dc6ccd37 100644 --- a/shared/signal/client/client.go +++ b/shared/signal/client/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/netip" "strings" "github.com/netbirdio/netbird/shared/signal/proto" @@ -14,17 +15,17 @@ import ( // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. -// Status is the status of the client -type Status string - -const StreamConnected Status = "Connected" -const StreamDisconnected Status = "Disconnected" - const ( + StreamConnected Status = "Connected" + StreamDisconnected Status = "Disconnected" + // DirectCheck indicates support to direct mode checks DirectCheck uint32 = 1 ) +// Status is the status of the client +type Status string + type Client interface { io.Closer StreamConnected() bool @@ -38,6 +39,24 @@ type Client interface { SetOnReconnectedListener(func()) } +// Credential is an instance of a GrpcClient's Credential +type Credential struct { + UFrag string + Pwd string +} + +// CredentialPayload bundles the fields of a signal Body for MarshalCredential. +type CredentialPayload struct { + Type proto.Body_Type + WgListenPort int + Credential *Credential + RosenpassPubKey []byte + RosenpassAddr string + RelaySrvAddress string + RelaySrvIP netip.Addr + SessionID []byte +} + // UnMarshalCredential parses the credentials from the message and returns a Credential instance func UnMarshalCredential(msg *proto.Message) (*Credential, error) { @@ -52,27 +71,27 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) { } // MarshalCredential marshal a Credential instance and returns a Message object -func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string, sessionID []byte) (*proto.Message, error) { +func MarshalCredential(myKey wgtypes.Key, remoteKey string, p CredentialPayload) (*proto.Message, error) { + body := &proto.Body{ + Type: p.Type, + Payload: fmt.Sprintf("%s:%s", p.Credential.UFrag, p.Credential.Pwd), + WgListenPort: uint32(p.WgListenPort), + NetBirdVersion: version.NetbirdVersion(), + RosenpassConfig: &proto.RosenpassConfig{ + RosenpassPubKey: p.RosenpassPubKey, + RosenpassServerAddr: p.RosenpassAddr, + }, + SessionId: p.SessionID, + } + if p.RelaySrvAddress != "" { + body.RelayServerAddress = &p.RelaySrvAddress + } + if p.RelaySrvIP.IsValid() { + body.RelayServerIP = p.RelaySrvIP.Unmap().AsSlice() + } return &proto.Message{ Key: myKey.PublicKey().String(), RemoteKey: remoteKey, - Body: &proto.Body{ - Type: t, - Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd), - WgListenPort: uint32(myPort), - NetBirdVersion: version.NetbirdVersion(), - RosenpassConfig: &proto.RosenpassConfig{ - RosenpassPubKey: rosenpassPubKey, - RosenpassServerAddr: rosenpassAddr, - }, - RelayServerAddress: relaySrvAddress, - SessionId: sessionID, - }, + Body: body, }, nil } - -// Credential is an instance of a GrpcClient's Credential -type Credential struct { - UFrag string - Pwd string -} diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5368b57a2..b245b2296 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -23,6 +23,8 @@ import ( "github.com/netbirdio/netbird/util/wsproxy" ) +const healthCheckTimeout = 5 * time.Second + // ConnStateNotifier is a wrapper interface of the status recorder type ConnStateNotifier interface { MarkSignalDisconnected(error) @@ -165,7 +167,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes // start receiving messages from the Signal stream (from other peers through signal) err = c.receive(stream) if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { + if ctx.Err() != nil { log.Debugf("signal connection context has been canceled, this usually indicates shutdown") return nil } @@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() _, err := c.realClient.Send(ctx, &proto.EncryptedMessage{ Key: c.key.PublicKey().String(), diff --git a/shared/signal/proto/signalexchange.pb.go b/shared/signal/proto/signalexchange.pb.go index d9c61a846..0c80fb489 100644 --- a/shared/signal/proto/signalexchange.pb.go +++ b/shared/signal/proto/signalexchange.pb.go @@ -229,8 +229,13 @@ type Body struct { // RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"` // relayServerAddress is url of the relay server - RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"` - SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"` + RelayServerAddress *string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3,oneof" json:"relayServerAddress,omitempty"` + SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"` + // relayServerIP is the IP the sender is connected to on its relay server, + // encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a + // fallback dial target when DNS resolution of relayServerAddress fails. + // SNI/TLS verification still uses relayServerAddress. + RelayServerIP []byte `protobuf:"bytes,11,opt,name=relayServerIP,proto3,oneof" json:"relayServerIP,omitempty"` } func (x *Body) Reset() { @@ -315,8 +320,8 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig { } func (x *Body) GetRelayServerAddress() string { - if x != nil { - return x.RelayServerAddress + if x != nil && x.RelayServerAddress != nil { + return *x.RelayServerAddress } return "" } @@ -328,6 +333,13 @@ func (x *Body) GetSessionId() []byte { return nil } +func (x *Body) GetRelayServerIP() []byte { + if x != nil { + return x.RelayServerIP + } + return nil +} + // Mode indicates a connection mode type Mode struct { state protoimpl.MessageState @@ -451,7 +463,7 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, - 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, + 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xc3, 0x04, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, @@ -471,40 +483,46 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x33, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, + 0x28, 0x09, 0x48, 0x00, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x88, 0x01, 0x01, 0x12, 0x21, 0x0a, 0x09, 0x73, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x01, + 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x29, + 0x0a, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x02, 0x52, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, - 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c, - 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04, - 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, - 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, - 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, - 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, - 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, - 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, - 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, - 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, - 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x15, + 0x0a, 0x13, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, + 0x6e, 0x49, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x49, 0x50, 0x4a, 0x04, 0x08, 0x09, 0x10, 0x0a, 0x22, 0x2e, 0x0a, 0x04, 0x4d, + 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, + 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, + 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, + 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, + 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, + 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, + 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, + 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/shared/signal/proto/signalexchange.proto b/shared/signal/proto/signalexchange.proto index 0a33ad78b..96a4001e3 100644 --- a/shared/signal/proto/signalexchange.proto +++ b/shared/signal/proto/signalexchange.proto @@ -63,9 +63,17 @@ message Body { RosenpassConfig rosenpassConfig = 7; // relayServerAddress is url of the relay server - string relayServerAddress = 8; + optional string relayServerAddress = 8; + + reserved 9; optional bytes sessionId = 10; + + // relayServerIP is the IP the sender is connected to on its relay server, + // encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a + // fallback dial target when DNS resolution of relayServerAddress fails. + // SNI/TLS verification still uses relayServerAddress. + optional bytes relayServerIP = 11; } // Mode indicates a connection mode diff --git a/tools/idp-migrate/DEVELOPMENT.md b/tools/idp-migrate/DEVELOPMENT.md new file mode 100644 index 000000000..5697ead40 --- /dev/null +++ b/tools/idp-migrate/DEVELOPMENT.md @@ -0,0 +1,209 @@ +# IdP Migration Tool — Developer Guide + +## Overview + +This tool migrates NetBird deployments from an external IdP (Auth0, Zitadel, Okta, etc.) to the embedded Dex IdP introduced in v0.62.0. It does two things: + +1. **DB migration** — Re-encodes every user ID from `{original_id}` to Dex's protobuf-encoded format `base64(proto{original_id, connector_id})`. +2. **Config generation** — Transforms `management.json`: removes `IdpManagerConfig`, `PKCEAuthorizationFlow`, and `DeviceAuthorizationFlow`; strips `HttpConfig` to only `CertFile`/`CertKey`; adds `EmbeddedIdP` with the static connector configuration. + +## Code Layout + +``` +tools/idp-migrate/ +├── config.go # migrationConfig struct, CLI flags, env vars, validation +├── main.go # CLI entry point, migration phases, config generation +├── main_test.go # 8 test functions (18 subtests) covering config, connector, URL builder, config generation +└── DEVELOPMENT.md # this file + +management/server/idp/migration/ +├── migration.go # Server interface, MigrateUsersToStaticConnectors(), PopulateUserInfo(), migrateUser(), reconcileActivityStore() +├── migration_test.go # 6 top-level tests (with subtests) using hand-written mocks +└── store.go # Store, EventStore interfaces, SchemaCheck, RequiredSchema, SchemaError types + +management/server/store/ +└── sql_store_idp_migration.go # CheckSchema(), ListUsers(), UpdateUserInfo(), UpdateUserID(), txDeferFKConstraints() on SqlStore + +management/server/activity/store/ +├── sql_store_idp_migration.go # UpdateUserID() on activity Store +└── sql_store_idp_migration_test.go # 5 subtests for activity UpdateUserID + +``` + +## Release / Distribution + +The tool is included in `.goreleaser.yaml` as the `netbird-idp-migrate` build target. Each NetBird release produces pre-built archives for Linux (amd64, arm64, arm) that are uploaded to GitHub Releases. The archive naming convention is: + +``` +netbird-idp-migrate__linux_.tar.gz +``` + +The build requires `CGO_ENABLED=1` because it links the SQLite driver used by `SqlStore`. The cross-compilation setup (CC env for arm64/arm) mirrors the `netbird-mgmt` build. + +## CLI Flags + +| Flag | Type | Default | Description | +|------|------|---------|-------------| +| `--config` | string | *(required)* | Path to management.json | +| `--datadir` | string | *(required)* | Data directory (containing store.db / events.db) | +| `--idp-seed-info` | string | *(required)* | Base64-encoded connector JSON | +| `--domain` | string | `""` | Sets both dashboard and API domain (convenience shorthand) | +| `--dashboard-domain` | string | *(required)* | Dashboard domain (for redirect URIs) | +| `--api-domain` | string | *(required)* | API domain (for Dex issuer and callback URLs) | +| `--dry-run` | bool | `false` | Preview changes without writing | +| `--force` | bool | `false` | Skip interactive confirmation prompt | +| `--skip-config` | bool | `false` | Skip config generation (DB-only migration) | +| `--skip-populate-user-info` | bool | `false` | Skip populating user info (user ID migration only) | +| `--log-level` | string | `"info"` | Log level (debug, info, warn, error) | + +## Environment Variables + +All flags can be overridden via environment variables. Env vars take precedence over flags. + +| Env Var | Overrides | +|---------|-----------| +| `NETBIRD_DOMAIN` | Sets both `--dashboard-domain` and `--api-domain` | +| `NETBIRD_API_URL` | `--api-domain` | +| `NETBIRD_DASHBOARD_URL` | `--dashboard-domain` | +| `NETBIRD_CONFIG_PATH` | `--config` | +| `NETBIRD_DATA_DIR` | `--datadir` | +| `NETBIRD_IDP_SEED_INFO` | `--idp-seed-info` | +| `NETBIRD_DRY_RUN` | `--dry-run` (set to `"true"`) | +| `NETBIRD_FORCE` | `--force` (set to `"true"`) | +| `NETBIRD_SKIP_CONFIG` | `--skip-config` (set to `"true"`) | +| `NETBIRD_SKIP_POPULATE_USER_INFO` | `--skip-populate-user-info` (set to `"true"`) | +| `NETBIRD_LOG_LEVEL` | `--log-level` | + +Resolution order: CLI flags are parsed first, then `--domain` sets both URLs, then `NETBIRD_DOMAIN` overrides both, then `NETBIRD_API_URL` / `NETBIRD_DASHBOARD_URL` override individually. After all resolution, `validateConfig()` ensures all required fields are set. + +## Migration Flow + +### Phase 0: Schema Validation + +`validateSchema()` opens the store and calls `CheckSchema(RequiredSchema)` to verify that all tables and columns required by the migration exist in the database. If anything is missing, the tool exits with a descriptive error instructing the operator to start the management server (v0.66.4+) at least once so that automatic GORM migrations create the required schema. + +### Phase 1: Populate User Info + +Unless `--skip-populate-user-info` is set, `populateUserInfoFromIDP()` runs before connector resolution: + +1. Creates an IDP manager from the existing `IdpManagerConfig` in management.json. +2. Calls `idpManager.GetAllAccounts()` to fetch email and name for all users from the external IDP. +3. Calls `migration.PopulateUserInfo()` which iterates over all store users, skipping service users and users that already have both email and name populated. For Dex-encoded user IDs, it decodes back to the original IDP ID for lookup. +4. Updates the store with any missing email/name values. + +This ensures user contact info is preserved before the ID migration makes the original IDP IDs inaccessible. + +### Phase 2: Connector Decoding + +`decodeConnectorConfig()` base64-decodes and JSON-unmarshals the connector JSON provided via `--idp-seed-info` (or `NETBIRD_IDP_SEED_INFO`). It validates that the connector ID is non-empty. There is no auto-detection or fallback — the operator must provide the full connector configuration. + +### Phase 3: DB Migration + +`migrateDB()` orchestrates the database migration: + +1. `openStores()` opens the main store (`SqlStore`) and activity store (non-fatal if missing). +2. Type-asserts both to `migration.Store` / `migration.EventStore`. +3. `previewUsers()` scans all users — counts pending vs already-migrated (using `DecodeDexUserID`). +4. `confirmPrompt()` asks for interactive confirmation (unless `--force` or `--dry-run`). +5. Calls `migration.MigrateUsersToStaticConnectors(srv, conn)`: + - **Reconciliation pass**: fixes activity store references for users already migrated in the main DB but whose events still reference old IDs (from a previous partial failure). + - **Main loop**: for each non-migrated user, calls `migrateUser()` which atomically updates the user ID in both the main store and activity store. + - **Dry-run**: logs what would happen, skips all writes. + +`SqlStore.UpdateUserID()` atomically updates the user's primary key and all foreign key references (peers, PATs, groups, policies, jobs, etc.) in a single transaction. + +### Phase 4: Config Generation + +Unless `--skip-config` is set, `generateConfig()` runs: + +1. **Read** — loads existing `management.json` as raw JSON to preserve unknown fields. + +2. **Strip** — removes keys that are no longer needed: + - `IdpManagerConfig` + - `PKCEAuthorizationFlow` + - `DeviceAuthorizationFlow` + - All `HttpConfig` fields except `CertFile` and `CertKey` + +3. **Add EmbeddedIdP** — inserts a minimal section with: + - `Enabled: true` + - `Issuer` built from `--api-domain` + `/oauth2` + - `DashboardRedirectURIs` built from `--dashboard-domain` + `/nb-auth` and `/nb-silent-auth` + - `StaticConnectors` containing the decoded connector, with `redirectURI` overridden to `--api-domain` + `/oauth2/callback` + +4. **Write** — backs up original as `management.json.bak`, writes new config. In dry-run mode, prints to stdout instead. + +## Interface Decoupling + +Migration methods (`ListUsers`, `UpdateUserID`) are **not** on the core `store.Store` or `activity.Store` interfaces. Instead, they're defined in `migration/store.go`: + +```go +type Store interface { + ListUsers(ctx context.Context) ([]*types.User, error) + UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error + UpdateUserInfo(ctx context.Context, userID, email, name string) error + CheckSchema(checks []SchemaCheck) []SchemaError +} + +type EventStore interface { + UpdateUserID(ctx context.Context, oldUserID, newUserID string) error +} +``` + +A `Server` interface wraps both stores for dependency injection: + +```go +type Server interface { + Store() Store + EventStore() EventStore // may return nil +} +``` + +The concrete `SqlStore` types already have these methods (in their respective `sql_store_idp_migration.go` files), so they satisfy the interfaces via Go's structural typing — zero changes needed on the core store interfaces. At runtime, the standalone tool type-asserts: + +```go +migStore, ok := mainStore.(migration.Store) +``` + +This keeps migration concerns completely separate from the core store contract. + +## Dex User ID Encoding + +`EncodeDexUserID(userID, connectorID)` produces a manually-encoded protobuf with two string fields, then base64-encodes the result (raw, no padding). `DecodeDexUserID` reverses this. The migration loop uses `DecodeDexUserID` to detect already-migrated users (decode succeeds → skip). + +See `idp/dex/provider.go` for the implementation. + +## Standalone Tool + +The standalone tool (`tools/idp-migrate/main.go`) is the primary migration entry point. It opens stores directly, runs schema validation, populates user info from the external IDP, migrates user IDs, and generates the new config — then exits. Configuration is handled entirely through `config.go` which parses CLI flags and environment variables. + +## Running Tests + +```bash +# Migration library +go test -v ./management/server/idp/migration/... + +# Standalone tool +go test -v ./tools/idp-migrate/... + +# Activity store migration tests +go test -v -run TestUpdateUserID ./management/server/activity/store/... + +# Build locally +go build ./tools/idp-migrate/ +``` + +## Clean Removal + +When migration tooling is no longer needed, delete: + +1. `tools/idp-migrate/` — entire directory +2. `management/server/idp/migration/` — entire directory +3. `management/server/store/sql_store_idp_migration.go` — migration methods on main SqlStore +4. `management/server/activity/store/sql_store_idp_migration.go` — migration method on activity Store +5. `management/server/activity/store/sql_store_idp_migration_test.go` — tests for the above +6. In `.goreleaser.yaml`: + - Remove the `netbird-idp-migrate` build entry + - Remove the `netbird-idp-migrate` archive entry +7. Run `go mod tidy` + +No core interfaces or mocks need editing — that's the point of the decoupling. diff --git a/tools/idp-migrate/LICENSE b/tools/idp-migrate/LICENSE new file mode 100644 index 000000000..be3f7b28e --- /dev/null +++ b/tools/idp-migrate/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/tools/idp-migrate/config.go b/tools/idp-migrate/config.go new file mode 100644 index 000000000..f4d6b9ea2 --- /dev/null +++ b/tools/idp-migrate/config.go @@ -0,0 +1,174 @@ +package main + +import ( + "flag" + "fmt" + "os" + "strconv" + + "github.com/netbirdio/netbird/util" +) + +type migrationConfig struct { + // Data + dashboardURL string + apiURL string + configPath string + dataDir string + idpSeedInfo string + + // Options + dryRun bool + force bool + skipConfig bool + skipPopulateUserInfo bool + + // Logging + logLevel string +} + +func config() (*migrationConfig, error) { + cfg, err := configFromArgs(os.Args[1:]) + if err != nil { + return nil, err + } + + if err := util.InitLog(cfg.logLevel, util.LogConsole); err != nil { + return nil, fmt.Errorf("init logger: %w", err) + } + + return cfg, nil +} + +func configFromArgs(args []string) (*migrationConfig, error) { + var cfg migrationConfig + var domain string + + fs := flag.NewFlagSet("netbird-idp-migrate", flag.ContinueOnError) + fs.StringVar(&domain, "domain", "", "domain for both dashboard and API") + fs.StringVar(&cfg.dashboardURL, "dashboard-url", "", "dashboard URL") + fs.StringVar(&cfg.apiURL, "api-url", "", "API URL") + fs.StringVar(&cfg.configPath, "config", "", "path to management.json (required)") + fs.StringVar(&cfg.dataDir, "datadir", "", "override data directory from config") + fs.StringVar(&cfg.idpSeedInfo, "idp-seed-info", "", "base64-encoded connector JSON (overrides auto-detection)") + fs.BoolVar(&cfg.dryRun, "dry-run", false, "preview changes without writing") + fs.BoolVar(&cfg.force, "force", false, "skip confirmation prompt") + fs.BoolVar(&cfg.skipConfig, "skip-config", false, "skip config generation (DB migration only)") + fs.BoolVar(&cfg.skipPopulateUserInfo, "skip-populate-user-info", false, "skip populating user info (user id migration only)") + fs.StringVar(&cfg.logLevel, "log-level", "info", "log level (debug, info, warn, error)") + + if err := fs.Parse(args); err != nil { + return nil, err + } + + applyOverrides(&cfg, domain) + + if err := validateConfig(&cfg); err != nil { + return nil, err + } + + return &cfg, nil +} + +// applyOverrides resolves domain configuration from broad to narrow sources. +// The most granular value always wins: +// +// --domain flag (broadest, only fills blanks) +// NETBIRD_DOMAIN env (overrides flags, sets both) +// --api-domain / --dashboard-domain flags (more specific than --domain) +// NETBIRD_API_URL / NETBIRD_DASHBOARD_URL env (most specific, always wins) +// +// Other env vars unconditionally override their corresponding flags. +func applyOverrides(cfg *migrationConfig, domain string) { + // --domain is a convenience shorthand: only fills in values not already + // set by the more specific --api-domain / --dashboard-domain flags. + if domain != "" { + if cfg.apiURL == "" { + cfg.apiURL = domain + } + if cfg.dashboardURL == "" { + cfg.dashboardURL = domain + } + } + + // Env vars override flags. Broad env var first, then narrow ones on top, + // so the most granular value always wins. + if val, ok := os.LookupEnv("NETBIRD_DOMAIN"); ok { + cfg.dashboardURL = val + cfg.apiURL = val + } + + if val, ok := os.LookupEnv("NETBIRD_API_URL"); ok { + cfg.apiURL = val + } + + if val, ok := os.LookupEnv("NETBIRD_DASHBOARD_URL"); ok { + cfg.dashboardURL = val + } + + if val, ok := os.LookupEnv("NETBIRD_CONFIG_PATH"); ok { + cfg.configPath = val + } + + if val, ok := os.LookupEnv("NETBIRD_DATA_DIR"); ok { + cfg.dataDir = val + } + + if val, ok := os.LookupEnv("NETBIRD_IDP_SEED_INFO"); ok { + cfg.idpSeedInfo = val + } + + // Enforce dry run if any value is provided + if sval, ok := os.LookupEnv("NETBIRD_DRY_RUN"); ok { + if val, err := strconv.ParseBool(sval); err == nil { + cfg.dryRun = val + } + } + + cfg.dryRun = parseBool("NETBIRD_DRY_RUN", cfg.dryRun) + cfg.force = parseBool("NETBIRD_FORCE", cfg.force) + cfg.skipConfig = parseBool("NETBIRD_SKIP_CONFIG", cfg.skipConfig) + cfg.skipPopulateUserInfo = parseBool("NETBIRD_SKIP_POPULATE_USER_INFO", cfg.skipPopulateUserInfo) + + if val, ok := os.LookupEnv("NETBIRD_LOG_LEVEL"); ok { + cfg.logLevel = val + } +} + +func parseBool(varName string, defaultVal bool) bool { + stringValue, ok := os.LookupEnv(varName) + if !ok { + return defaultVal + } + + boolValue, err := strconv.ParseBool(stringValue) + if err != nil { + return defaultVal + } + + return boolValue +} + +func validateConfig(cfg *migrationConfig) error { + if cfg.configPath == "" { + return fmt.Errorf("--config is required") + } + + if cfg.dataDir == "" { + return fmt.Errorf("--datadir is required") + } + + if cfg.idpSeedInfo == "" { + return fmt.Errorf("--idp-seed-info is required") + } + + if cfg.apiURL == "" { + return fmt.Errorf("--api-domain is required") + } + + if cfg.dashboardURL == "" { + return fmt.Errorf("--dashboard-domain is required") + } + + return nil +} diff --git a/tools/idp-migrate/main.go b/tools/idp-migrate/main.go new file mode 100644 index 000000000..a8cba0750 --- /dev/null +++ b/tools/idp-migrate/main.go @@ -0,0 +1,449 @@ +// Package main provides a standalone CLI tool to migrate user IDs from an +// external IdP format to the embedded Dex IdP format used by NetBird >= v0.62.0. +// +// This tool reads management.json to auto-detect the current external IdP +// configuration (issuer, clientID, clientSecret, type) and re-encodes all user +// IDs in the database to the Dex protobuf-encoded format. It works independently +// of migrate.sh and the combined server, allowing operators to migrate their +// database before switching to the combined server. +// +// Usage: +// +// netbird-idp-migrate --config /etc/netbird/management.json [--dry-run] [--force] +package main + +import ( + "bufio" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "maps" + "net/url" + "os" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/idp/dex" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + activitystore "github.com/netbirdio/netbird/management/server/activity/store" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/idp/migration" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/crypt" +) + +// migrationServer implements migration.Server by wrapping the migration-specific interfaces. +type migrationServer struct { + store migration.Store + eventStore migration.EventStore +} + +func (s *migrationServer) Store() migration.Store { return s.store } +func (s *migrationServer) EventStore() migration.EventStore { return s.eventStore } + +func main() { + cfg, err := config() + if err != nil { + log.Fatalf("config error: %v", err) + } + + if err := run(cfg); err != nil { + log.Fatalf("migration failed: %v", err) + } + + if !cfg.dryRun { + printPostMigrationInstructions(cfg) + } +} + +func run(cfg *migrationConfig) error { + mgmtConfig := &nbconfig.Config{} + if _, err := util.ReadJsonWithEnvSub(cfg.configPath, mgmtConfig); err != nil { + return err + } + + // Validate the database schema before attempting any operations. + if err := validateSchema(mgmtConfig, cfg.dataDir); err != nil { + return err + } + + if !cfg.skipPopulateUserInfo { + err := populateUserInfoFromIDP(cfg, mgmtConfig) + if err != nil { + return fmt.Errorf("populate user info: %w", err) + } + } + + connectorConfig, err := decodeConnectorConfig(cfg.idpSeedInfo) + if err != nil { + return fmt.Errorf("resolve connector: %w", err) + } + + log.Infof( + "resolved connector: type=%s, id=%s, name=%s", + connectorConfig.Type, + connectorConfig.ID, + connectorConfig.Name, + ) + + if err := migrateDB(cfg, mgmtConfig, connectorConfig); err != nil { + return err + } + + if cfg.skipConfig { + log.Info("skipping config generation (--skip-config)") + return nil + } + + return generateConfig(cfg, connectorConfig) +} + +// validateSchema opens the store and checks that all required tables and columns +// exist. If anything is missing, it returns a descriptive error telling the user +// to upgrade their management server. +func validateSchema(mgmtConfig *nbconfig.Config, dataDir string) error { + ctx := context.Background() + migStore, migEventStore, cleanup, err := openStores(ctx, mgmtConfig, dataDir) + if err != nil { + return err + } + defer cleanup() + + errs := migStore.CheckSchema(migration.RequiredSchema) + if len(errs) > 0 { + return fmt.Errorf("%s", formatSchemaErrors(errs)) + } + + if migEventStore != nil { + eventErrs := migEventStore.CheckSchema(migration.RequiredEventSchema) + if len(eventErrs) > 0 { + return fmt.Errorf("activity store schema check failed (upgrade management server first):\n%s", formatSchemaErrors(eventErrs)) + } + } + + log.Info("database schema check passed") + return nil +} + +// formatSchemaErrors returns a user-friendly message listing all missing schema +// elements and instructing the operator to upgrade. +func formatSchemaErrors(errs []migration.SchemaError) string { + var b strings.Builder + b.WriteString("database schema is incomplete — the following tables/columns are missing:\n") + for _, e := range errs { + fmt.Fprintf(&b, " - %s\n", e.String()) + } + b.WriteString("\nPlease start the NetBird management server (v0.66.4+) at least once so that automatic database migrations create the required schema, then re-run this tool.\n") + return b.String() +} + +// populateUserInfoFromIDP creates an IDP manager from the config, fetches all +// user data (email, name) from the external IDP, and updates the store for users +// that are missing this information. +func populateUserInfoFromIDP(cfg *migrationConfig, mgmtConfig *nbconfig.Config) error { + ctx := context.Background() + + if mgmtConfig.IdpManagerConfig == nil { + return fmt.Errorf("IdpManagerConfig is not set in management.json; cannot fetch user info from IDP") + } + + idpManager, err := idp.NewManager(ctx, *mgmtConfig.IdpManagerConfig, nil) + if err != nil { + return fmt.Errorf("create IDP manager: %w", err) + } + if idpManager == nil { + return fmt.Errorf("IDP manager type is 'none' or empty; cannot fetch user info") + } + + log.Infof("created IDP manager (type: %s)", mgmtConfig.IdpManagerConfig.ManagerType) + + migStore, _, cleanup, err := openStores(ctx, mgmtConfig, cfg.dataDir) + if err != nil { + return err + } + defer cleanup() + + srv := &migrationServer{store: migStore} + return migration.PopulateUserInfo(srv, idpManager, cfg.dryRun) +} + +// openStores opens the main and activity stores, returning migration-specific interfaces. +// The caller must call the returned cleanup function to close the stores. +func openStores(ctx context.Context, cfg *nbconfig.Config, dataDir string) (migration.Store, migration.EventStore, func(), error) { + engine := cfg.StoreConfig.Engine + if engine == "" { + engine = types.SqliteStoreEngine + } + + mainStore, err := store.NewStore(ctx, engine, dataDir, nil, true) + if err != nil { + return nil, nil, nil, fmt.Errorf("open main store: %w", err) + } + + if cfg.DataStoreEncryptionKey != "" { + fieldEncrypt, err := crypt.NewFieldEncrypt(cfg.DataStoreEncryptionKey) + if err != nil { + _ = mainStore.Close(ctx) + return nil, nil, nil, fmt.Errorf("init field encryption: %w", err) + } + mainStore.SetFieldEncrypt(fieldEncrypt) + } + + migStore, ok := mainStore.(migration.Store) + if !ok { + _ = mainStore.Close(ctx) + return nil, nil, nil, fmt.Errorf("store does not support migration operations (ListUsers/UpdateUserID)") + } + + cleanup := func() { _ = mainStore.Close(ctx) } + + var migEventStore migration.EventStore + actStore, err := activitystore.NewSqlStore(ctx, dataDir, cfg.DataStoreEncryptionKey) + if err != nil { + log.Warnf("could not open activity store (events.db may not exist): %v", err) + } else { + migEventStore = actStore + prevCleanup := cleanup + cleanup = func() { _ = actStore.Close(ctx); prevCleanup() } + } + + return migStore, migEventStore, cleanup, nil +} + +// migrateDB opens the stores, previews pending users, and runs the DB migration. +func migrateDB(cfg *migrationConfig, mgmtConfig *nbconfig.Config, connectorConfig *dex.Connector) error { + ctx := context.Background() + + migStore, migEventStore, cleanup, err := openStores(ctx, mgmtConfig, cfg.dataDir) + if err != nil { + return err + } + defer cleanup() + + pending, err := previewUsers(ctx, migStore) + if err != nil { + return err + } + + if cfg.dryRun { + if err := os.Setenv("NB_IDP_MIGRATION_DRY_RUN", "true"); err != nil { + return fmt.Errorf("set dry-run env: %w", err) + } + defer os.Unsetenv("NB_IDP_MIGRATION_DRY_RUN") //nolint:errcheck + } + + if !cfg.dryRun && !cfg.force { + if !confirmPrompt(pending) { + log.Info("migration cancelled by user") + return nil + } + } + + srv := &migrationServer{store: migStore, eventStore: migEventStore} + if err := migration.MigrateUsersToStaticConnectors(srv, connectorConfig); err != nil { + return fmt.Errorf("migrate users: %w", err) + } + + if !cfg.dryRun { + log.Info("DB migration completed successfully") + } + return nil +} + +// previewUsers counts pending vs already-migrated users and logs a summary. +// Returns the number of users still needing migration. +func previewUsers(ctx context.Context, migStore migration.Store) (int, error) { + users, err := migStore.ListUsers(ctx) + if err != nil { + return 0, fmt.Errorf("list users: %w", err) + } + + var pending, alreadyMigrated int + for _, u := range users { + if _, _, decErr := dex.DecodeDexUserID(u.Id); decErr == nil { + alreadyMigrated++ + } else { + pending++ + } + } + + log.Infof("found %d total users: %d pending migration, %d already migrated", len(users), pending, alreadyMigrated) + return pending, nil +} + +// confirmPrompt asks the user for interactive confirmation. Returns true if they accept. +func confirmPrompt(pending int) bool { + log.Infof("About to migrate %d users. This cannot be easily undone. Continue? [y/N] ", pending) + reader := bufio.NewReader(os.Stdin) + answer, _ := reader.ReadString('\n') + answer = strings.TrimSpace(strings.ToLower(answer)) + return answer == "y" || answer == "yes" +} + +// decodeConnectorConfig base64-decodes and JSON-unmarshals a connector. +func decodeConnectorConfig(encoded string) (*dex.Connector, error) { + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("base64 decode: %w", err) + } + + var conn dex.Connector + if err := json.Unmarshal(decoded, &conn); err != nil { + return nil, fmt.Errorf("json unmarshal: %w", err) + } + + if conn.ID == "" { + return nil, fmt.Errorf("connector ID is empty") + } + + return &conn, nil +} + +// generateConfig reads the existing management.json as raw JSON, removes +// IdpManagerConfig, adds EmbeddedIdP, updates HttpConfig fields, and writes +// the result. In dry-run mode, it prints the new config to stdout instead. +func generateConfig(cfg *migrationConfig, connectorConfig *dex.Connector) error { + // Read existing config as raw JSON to preserve all fields + raw, err := os.ReadFile(cfg.configPath) + if err != nil { + return fmt.Errorf("read config file: %w", err) + } + + var configMap map[string]any + if err := json.Unmarshal(raw, &configMap); err != nil { + return fmt.Errorf("parse config JSON: %w", err) + } + + // Remove unused information + delete(configMap, "IdpManagerConfig") + delete(configMap, "PKCEAuthorizationFlow") + delete(configMap, "DeviceAuthorizationFlow") + + httpConfig, ok := configMap["HttpConfig"].(map[string]any) + if httpConfig != nil && ok { + certFilePath := httpConfig["CertFile"] + certKeyPath := httpConfig["CertKey"] + + delete(configMap, "HttpConfig") + + configMap["HttpConfig"] = map[string]any{ + "CertFile": certFilePath, + "CertKey": certKeyPath, + } + } + + // Ensure the connector's redirectURI points to the management server (Dex callback), + // not the external IdP. The auto-detection may have used the IdP issuer URL. + connConfig := make(map[string]any, len(connectorConfig.Config)) + maps.Copy(connConfig, connectorConfig.Config) + + redirectURI, err := buildURL(cfg.apiURL, "/oauth2/callback") + if err != nil { + return fmt.Errorf("build redirect URI: %w", err) + } + connConfig["redirectURI"] = redirectURI + + issuer, err := buildURL(cfg.apiURL, "/oauth2") + if err != nil { + return fmt.Errorf("build issuer URL: %w", err) + } + + dashboardRedirectURL, err := buildURL(cfg.dashboardURL, "/nb-auth") + if err != nil { + return fmt.Errorf("build dashboard redirect URL: %w", err) + } + + dashboardSilentRedirectURL, err := buildURL(cfg.dashboardURL, "/nb-silent-auth") + if err != nil { + return fmt.Errorf("build dashboard silent redirect URL: %w", err) + } + + // Add minimal EmbeddedIdP section + configMap["EmbeddedIdP"] = map[string]any{ + "Enabled": true, + "Issuer": issuer, + "DashboardRedirectURIs": []string{ + dashboardRedirectURL, + dashboardSilentRedirectURL, + }, + "StaticConnectors": []any{ + map[string]any{ + "type": connectorConfig.Type, + "name": connectorConfig.Name, + "id": connectorConfig.ID, + "config": connConfig, + }, + }, + } + + newJSON, err := json.MarshalIndent(configMap, "", " ") + if err != nil { + return fmt.Errorf("marshal new config: %w", err) + } + + if cfg.dryRun { + log.Info("[DRY RUN] new management.json would be:") + log.Infoln(string(newJSON)) + return nil + } + + // Backup original + backupPath := cfg.configPath + ".bak" + if err := os.WriteFile(backupPath, raw, 0o600); err != nil { + return fmt.Errorf("write backup: %w", err) + } + log.Infof("backed up original config to %s", backupPath) + + // Write new config + if err := os.WriteFile(cfg.configPath, newJSON, 0o600); err != nil { + return fmt.Errorf("write new config: %w", err) + } + log.Infof("wrote new config to %s", cfg.configPath) + + return nil +} + +func buildURL(uri, path string) (string, error) { + // Case for domain without scheme, e.g. "example.com" or "example.com:8080" + if !strings.HasPrefix(uri, "http://") && !strings.HasPrefix(uri, "https://") { + uri = "https://" + uri + } + + val, err := url.JoinPath(uri, path) + if err != nil { + return "", err + } + + return val, nil +} + +func printPostMigrationInstructions(cfg *migrationConfig) { + authAuthority, err := buildURL(cfg.apiURL, "/oauth2") + if err != nil { + authAuthority = "https:///oauth2" + } + + log.Info("Congratulations! You have successfully migrated your NetBird management server to the embedded Dex IdP.") + log.Info("Next steps:") + log.Info("1. Make sure the following environment variables are set for your dashboard server:") + log.Infof(` +AUTH_AUDIENCE=netbird-dashboard +AUTH_CLIENT_ID=netbird-dashboard +AUTH_AUTHORITY=%s +AUTH_SUPPORTED_SCOPES=openid profile email groups +AUTH_REDIRECT_URI=/nb-auth +AUTH_SILENT_REDIRECT_URI=/nb-silent-auth + `, + authAuthority, + ) + log.Info("2. Make sure you restart the dashboard & management servers to pick up the new config and environment variables.") + log.Info("eg. docker compose up -d --force-recreate management dashboard") + log.Info("3. Optional: If you have a reverse proxy configured, make sure the path `/oauth2/*` points to the management api server.") +} + +// Compile-time check that migrationServer implements migration.Server. +var _ migration.Server = (*migrationServer)(nil) diff --git a/tools/idp-migrate/main_test.go b/tools/idp-migrate/main_test.go new file mode 100644 index 000000000..75d0bd7eb --- /dev/null +++ b/tools/idp-migrate/main_test.go @@ -0,0 +1,487 @@ +package main + +import ( + "encoding/base64" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/idp/migration" +) + +// TestMigrationServerInterface is a compile-time check that migrationServer +// implements the migration.Server interface. +func TestMigrationServerInterface(t *testing.T) { + var _ migration.Server = (*migrationServer)(nil) +} + +func TestDecodeConnectorConfig(t *testing.T) { + conn := dex.Connector{ + Type: "oidc", + Name: "test", + ID: "test-id", + Config: map[string]any{ + "issuer": "https://example.com", + "clientID": "cid", + "clientSecret": "csecret", + }, + } + + data, err := json.Marshal(conn) + require.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + + result, err := decodeConnectorConfig(encoded) + require.NoError(t, err) + assert.Equal(t, "test-id", result.ID) + assert.Equal(t, "oidc", result.Type) + assert.Equal(t, "https://example.com", result.Config["issuer"]) +} + +func TestDecodeConnectorConfig_InvalidBase64(t *testing.T) { + _, err := decodeConnectorConfig("not-valid-base64!!!") + require.Error(t, err) + assert.Contains(t, err.Error(), "base64 decode") +} + +func TestDecodeConnectorConfig_InvalidJSON(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("not json")) + _, err := decodeConnectorConfig(encoded) + require.Error(t, err) + assert.Contains(t, err.Error(), "json unmarshal") +} + +func TestDecodeConnectorConfig_EmptyConnectorID(t *testing.T) { + conn := dex.Connector{ + Type: "oidc", + Name: "no-id", + ID: "", + } + data, err := json.Marshal(conn) + require.NoError(t, err) + + encoded := base64.StdEncoding.EncodeToString(data) + _, err = decodeConnectorConfig(encoded) + require.Error(t, err) + assert.Contains(t, err.Error(), "connector ID is empty") +} + +func TestValidateConfig(t *testing.T) { + valid := &migrationConfig{ + configPath: "/etc/netbird/management.json", + dataDir: "/var/lib/netbird", + idpSeedInfo: "some-base64", + apiURL: "https://api.example.com", + dashboardURL: "https://dash.example.com", + } + + t.Run("valid config", func(t *testing.T) { + require.NoError(t, validateConfig(valid)) + }) + + t.Run("missing configPath", func(t *testing.T) { + cfg := *valid + cfg.configPath = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--config") + }) + + t.Run("missing dataDir", func(t *testing.T) { + cfg := *valid + cfg.dataDir = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--datadir") + }) + + t.Run("missing idpSeedInfo", func(t *testing.T) { + cfg := *valid + cfg.idpSeedInfo = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--idp-seed-info") + }) + + t.Run("missing apiUrl", func(t *testing.T) { + cfg := *valid + cfg.apiURL = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--api-domain") + }) + + t.Run("missing dashboardUrl", func(t *testing.T) { + cfg := *valid + cfg.dashboardURL = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--dashboard-domain") + }) +} + +func TestConfigFromArgs_EnvVarsApplied(t *testing.T) { + t.Run("env vars fill in for missing flags", func(t *testing.T) { + t.Setenv("NETBIRD_CONFIG_PATH", "/env/management.json") + t.Setenv("NETBIRD_DATA_DIR", "/env/data") + t.Setenv("NETBIRD_IDP_SEED_INFO", "env-seed") + t.Setenv("NETBIRD_API_URL", "https://api.env.com") + t.Setenv("NETBIRD_DASHBOARD_URL", "https://dash.env.com") + + cfg, err := configFromArgs([]string{}) + require.NoError(t, err) + + assert.Equal(t, "/env/management.json", cfg.configPath) + assert.Equal(t, "/env/data", cfg.dataDir) + assert.Equal(t, "env-seed", cfg.idpSeedInfo) + assert.Equal(t, "https://api.env.com", cfg.apiURL) + assert.Equal(t, "https://dash.env.com", cfg.dashboardURL) + }) + + t.Run("flags work without env vars", func(t *testing.T) { + cfg, err := configFromArgs([]string{ + "--config", "/flag/management.json", + "--datadir", "/flag/data", + "--idp-seed-info", "flag-seed", + "--api-url", "https://api.flag.com", + "--dashboard-url", "https://dash.flag.com", + }) + require.NoError(t, err) + + assert.Equal(t, "/flag/management.json", cfg.configPath) + assert.Equal(t, "/flag/data", cfg.dataDir) + assert.Equal(t, "flag-seed", cfg.idpSeedInfo) + assert.Equal(t, "https://api.flag.com", cfg.apiURL) + assert.Equal(t, "https://dash.flag.com", cfg.dashboardURL) + }) + + t.Run("env vars override flags", func(t *testing.T) { + t.Setenv("NETBIRD_CONFIG_PATH", "/env/management.json") + t.Setenv("NETBIRD_API_URL", "https://api.env.com") + + cfg, err := configFromArgs([]string{ + "--config", "/flag/management.json", + "--datadir", "/flag/data", + "--idp-seed-info", "flag-seed", + "--api-url", "https://api.flag.com", + "--dashboard-url", "https://dash.flag.com", + }) + require.NoError(t, err) + + assert.Equal(t, "/env/management.json", cfg.configPath, "env should override flag") + assert.Equal(t, "https://api.env.com", cfg.apiURL, "env should override flag") + assert.Equal(t, "https://dash.flag.com", cfg.dashboardURL, "flag preserved when no env override") + }) + + t.Run("--domain flag with specific env var override", func(t *testing.T) { + t.Setenv("NETBIRD_API_URL", "https://api.env.com") + + cfg, err := configFromArgs([]string{ + "--domain", "both.flag.com", + "--config", "/path", + "--datadir", "/data", + "--idp-seed-info", "seed", + }) + require.NoError(t, err) + + assert.Equal(t, "https://api.env.com", cfg.apiURL, "specific env beats --domain") + assert.Equal(t, "both.flag.com", cfg.dashboardURL, "--domain fills dashboard") + }) +} + +func TestApplyOverrides_MostGranularWins(t *testing.T) { + t.Run("specific flags beat --domain", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.specific.com", + dashboardURL: "dash.specific.com", + } + applyOverrides(cfg, "broad.com") + + assert.Equal(t, "api.specific.com", cfg.apiURL) + assert.Equal(t, "dash.specific.com", cfg.dashboardURL) + }) + + t.Run("--domain fills blanks when specific flags missing", func(t *testing.T) { + cfg := &migrationConfig{} + applyOverrides(cfg, "broad.com") + + assert.Equal(t, "broad.com", cfg.apiURL) + assert.Equal(t, "broad.com", cfg.dashboardURL) + }) + + t.Run("--domain fills only the missing specific flag", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.specific.com", + } + applyOverrides(cfg, "broad.com") + + assert.Equal(t, "api.specific.com", cfg.apiURL) + assert.Equal(t, "broad.com", cfg.dashboardURL) + }) + + t.Run("NETBIRD_DOMAIN overrides flags", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.flag.com", + dashboardURL: "dash.flag.com", + } + t.Setenv("NETBIRD_DOMAIN", "env-broad.com") + + applyOverrides(cfg, "") + + assert.Equal(t, "env-broad.com", cfg.apiURL) + assert.Equal(t, "env-broad.com", cfg.dashboardURL) + }) + + t.Run("specific env vars beat NETBIRD_DOMAIN", func(t *testing.T) { + cfg := &migrationConfig{} + t.Setenv("NETBIRD_DOMAIN", "env-broad.com") + t.Setenv("NETBIRD_API_URL", "api.env-specific.com") + t.Setenv("NETBIRD_DASHBOARD_URL", "dash.env-specific.com") + + applyOverrides(cfg, "") + + assert.Equal(t, "api.env-specific.com", cfg.apiURL) + assert.Equal(t, "dash.env-specific.com", cfg.dashboardURL) + }) + + t.Run("one specific env var overrides only its field", func(t *testing.T) { + cfg := &migrationConfig{} + t.Setenv("NETBIRD_DOMAIN", "env-broad.com") + t.Setenv("NETBIRD_API_URL", "api.env-specific.com") + + applyOverrides(cfg, "") + + assert.Equal(t, "api.env-specific.com", cfg.apiURL) + assert.Equal(t, "env-broad.com", cfg.dashboardURL) + }) + + t.Run("specific env vars beat all flags combined", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.flag.com", + dashboardURL: "dash.flag.com", + } + t.Setenv("NETBIRD_API_URL", "api.env.com") + t.Setenv("NETBIRD_DASHBOARD_URL", "dash.env.com") + + applyOverrides(cfg, "domain-flag.com") + + assert.Equal(t, "api.env.com", cfg.apiURL) + assert.Equal(t, "dash.env.com", cfg.dashboardURL) + }) + + t.Run("env vars override all non-domain flags", func(t *testing.T) { + cfg := &migrationConfig{ + configPath: "/flag/path", + dataDir: "/flag/data", + idpSeedInfo: "flag-seed", + dryRun: false, + force: false, + skipConfig: false, + skipPopulateUserInfo: false, + logLevel: "info", + } + t.Setenv("NETBIRD_CONFIG_PATH", "/env/path") + t.Setenv("NETBIRD_DATA_DIR", "/env/data") + t.Setenv("NETBIRD_IDP_SEED_INFO", "env-seed") + t.Setenv("NETBIRD_DRY_RUN", "true") + t.Setenv("NETBIRD_FORCE", "true") + t.Setenv("NETBIRD_SKIP_CONFIG", "true") + t.Setenv("NETBIRD_SKIP_POPULATE_USER_INFO", "true") + t.Setenv("NETBIRD_LOG_LEVEL", "debug") + + applyOverrides(cfg, "") + + assert.Equal(t, "/env/path", cfg.configPath) + assert.Equal(t, "/env/data", cfg.dataDir) + assert.Equal(t, "env-seed", cfg.idpSeedInfo) + assert.True(t, cfg.dryRun) + assert.True(t, cfg.force) + assert.True(t, cfg.skipConfig) + assert.True(t, cfg.skipPopulateUserInfo) + assert.Equal(t, "debug", cfg.logLevel) + }) + + t.Run("boolean env vars properly parse false values", func(t *testing.T) { + cfg := &migrationConfig{} + t.Setenv("NETBIRD_DRY_RUN", "false") + t.Setenv("NETBIRD_FORCE", "yes") + t.Setenv("NETBIRD_SKIP_CONFIG", "0") + + applyOverrides(cfg, "") + + assert.False(t, cfg.dryRun) + assert.False(t, cfg.force) + assert.False(t, cfg.skipConfig) + }) + + t.Run("unset env vars do not override flags", func(t *testing.T) { + cfg := &migrationConfig{ + configPath: "/flag/path", + dataDir: "/flag/data", + idpSeedInfo: "flag-seed", + dryRun: true, + logLevel: "warn", + } + + applyOverrides(cfg, "") + + assert.Equal(t, "/flag/path", cfg.configPath) + assert.Equal(t, "/flag/data", cfg.dataDir) + assert.Equal(t, "flag-seed", cfg.idpSeedInfo) + assert.True(t, cfg.dryRun) + assert.Equal(t, "warn", cfg.logLevel) + }) +} + +func TestBuildUrl(t *testing.T) { + tests := []struct { + name string + uri string + path string + expected string + }{ + {"with https scheme", "https://example.com", "/oauth2", "https://example.com/oauth2"}, + {"with http scheme", "http://example.com", "/oauth2/callback", "http://example.com/oauth2/callback"}, + {"bare domain", "example.com", "/oauth2", "https://example.com/oauth2"}, + {"domain with port", "example.com:8080", "/nb-auth", "https://example.com:8080/nb-auth"}, + {"trailing slash on uri", "https://example.com/", "/oauth2", "https://example.com/oauth2"}, + {"nested path", "https://example.com", "/oauth2/callback", "https://example.com/oauth2/callback"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url, err := buildURL(tt.uri, tt.path) + assert.NoError(t, err) + assert.Equal(t, tt.expected, url) + }) + } +} + +func TestGenerateConfig(t *testing.T) { + t.Run("generates valid config", func(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "management.json") + + originalConfig := `{ + "Datadir": "/var/lib/netbird", + "HttpConfig": { + "LetsEncryptDomain": "mgmt.example.com", + "CertFile": "/etc/ssl/cert.pem", + "CertKey": "/etc/ssl/key.pem", + "AuthIssuer": "https://zitadel.example.com/oauth2", + "AuthKeysLocation": "https://zitadel.example.com/oauth2/keys", + "OIDCConfigEndpoint": "https://zitadel.example.com/.well-known/openid-configuration", + "AuthClientID": "old-client-id", + "AuthUserIDClaim": "preferred_username" + }, + "IdpManagerConfig": { + "ManagerType": "zitadel", + "ClientConfig": { + "Issuer": "https://zitadel.example.com", + "ClientID": "zit-id", + "ClientSecret": "zit-secret" + } + } +}` + require.NoError(t, os.WriteFile(configPath, []byte(originalConfig), 0o600)) + + cfg := &migrationConfig{ + configPath: configPath, + dashboardURL: "https://mgmt.example.com", + apiURL: "https://mgmt.example.com", + } + conn := &dex.Connector{ + Type: "zitadel", + Name: "zitadel", + ID: "zitadel", + Config: map[string]any{ + "issuer": "https://zitadel.example.com", + "clientID": "zit-id", + "clientSecret": "zit-secret", + }, + } + + err := generateConfig(cfg, conn) + require.NoError(t, err) + + // Check backup was created + backupPath := configPath + ".bak" + backupData, err := os.ReadFile(backupPath) + require.NoError(t, err) + assert.Equal(t, originalConfig, string(backupData)) + + // Read and parse the new config + newData, err := os.ReadFile(configPath) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(newData, &result)) + + // IdpManagerConfig should be removed + _, hasOldIdp := result["IdpManagerConfig"] + assert.False(t, hasOldIdp, "IdpManagerConfig should be removed") + + _, hasPKCE := result["PKCEAuthorizationFlow"] + assert.False(t, hasPKCE, "PKCEAuthorizationFlow should be removed") + + // EmbeddedIdP should be present with minimal fields + embeddedIDP, ok := result["EmbeddedIdP"].(map[string]any) + require.True(t, ok, "EmbeddedIdP should be present") + assert.Equal(t, true, embeddedIDP["Enabled"]) + assert.Equal(t, "https://mgmt.example.com/oauth2", embeddedIDP["Issuer"]) + assert.Nil(t, embeddedIDP["LocalAuthDisabled"], "LocalAuthDisabled should not be set") + assert.Nil(t, embeddedIDP["SignKeyRefreshEnabled"], "SignKeyRefreshEnabled should not be set") + assert.Nil(t, embeddedIDP["CLIRedirectURIs"], "CLIRedirectURIs should not be set") + + // Static connector's redirectURI should use the management domain + connectors := embeddedIDP["StaticConnectors"].([]any) + require.Len(t, connectors, 1) + firstConn := connectors[0].(map[string]any) + connCfg := firstConn["config"].(map[string]any) + assert.Equal(t, "https://mgmt.example.com/oauth2/callback", connCfg["redirectURI"], + "redirectURI should be overridden to use the management domain") + + // HttpConfig should only have CertFile and CertKey + httpConfig, ok := result["HttpConfig"].(map[string]any) + require.True(t, ok, "HttpConfig should be present") + assert.Equal(t, "/etc/ssl/cert.pem", httpConfig["CertFile"]) + assert.Equal(t, "/etc/ssl/key.pem", httpConfig["CertKey"]) + assert.Nil(t, httpConfig["AuthIssuer"], "AuthIssuer should be stripped") + + // Datadir should be preserved + assert.Equal(t, "/var/lib/netbird", result["Datadir"]) + }) + + t.Run("dry run does not write files", func(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "management.json") + + originalConfig := `{"HttpConfig": {"CertFile": "", "CertKey": ""}}` + require.NoError(t, os.WriteFile(configPath, []byte(originalConfig), 0o600)) + + cfg := &migrationConfig{ + configPath: configPath, + dashboardURL: "https://mgmt.example.com", + apiURL: "https://mgmt.example.com", + dryRun: true, + } + conn := &dex.Connector{Type: "oidc", Name: "test", ID: "test"} + + err := generateConfig(cfg, conn) + require.NoError(t, err) + + // Original should be unchanged + data, err := os.ReadFile(configPath) + require.NoError(t, err) + assert.Equal(t, originalConfig, string(data)) + + // No backup should exist + _, err = os.Stat(configPath + ".bak") + assert.True(t, os.IsNotExist(err)) + }) +} diff --git a/upload-server/server/local.go b/upload-server/server/local.go index f12c472d2..f7ca50011 100644 --- a/upload-server/server/local.go +++ b/upload-server/server/local.go @@ -7,6 +7,7 @@ import ( "net/url" "os" "path/filepath" + "strings" log "github.com/sirupsen/logrus" @@ -82,15 +83,18 @@ func (l *local) getUploadURL(objectKey string) (string, error) { return newURL.String(), nil } +const maxUploadSize = 150 << 20 + func (l *local) handlePutRequest(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } + r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize) body, err := io.ReadAll(r.Body) if err != nil { - http.Error(w, fmt.Sprintf("failed to read body: %v", err), http.StatusInternalServerError) + http.Error(w, "request body too large or failed to read", http.StatusRequestEntityTooLarge) return } @@ -105,20 +109,47 @@ func (l *local) handlePutRequest(w http.ResponseWriter, r *http.Request) { return } - dirPath := filepath.Join(l.dir, uploadDir) - err = os.MkdirAll(dirPath, 0750) - if err != nil { + cleanBase := filepath.Clean(l.dir) + string(filepath.Separator) + + dirPath := filepath.Clean(filepath.Join(l.dir, uploadDir)) + if !strings.HasPrefix(dirPath, cleanBase) { + http.Error(w, "invalid path", http.StatusBadRequest) + log.Warnf("Path traversal attempt blocked (dir): %s", dirPath) + return + } + + filePath := filepath.Clean(filepath.Join(dirPath, uploadFile)) + if !strings.HasPrefix(filePath, cleanBase) { + http.Error(w, "invalid path", http.StatusBadRequest) + log.Warnf("Path traversal attempt blocked (file): %s", filePath) + return + } + + if err = os.MkdirAll(dirPath, 0750); err != nil { http.Error(w, "failed to create upload dir", http.StatusInternalServerError) log.Errorf("Failed to create upload dir: %v", err) return } - file := filepath.Join(dirPath, uploadFile) - if err := os.WriteFile(file, body, 0600); err != nil { - http.Error(w, "failed to write file", http.StatusInternalServerError) - log.Errorf("Failed to write file %s: %v", file, err) + flags := os.O_WRONLY | os.O_CREATE | os.O_EXCL + f, err := os.OpenFile(filePath, flags, 0600) + if err != nil { + if os.IsExist(err) { + http.Error(w, "file already exists", http.StatusConflict) + return + } + http.Error(w, "failed to create file", http.StatusInternalServerError) + log.Errorf("Failed to create file %s: %v", filePath, err) return } - log.Infof("Uploading file %s", file) + defer func() { _ = f.Close() }() + + if _, err = f.Write(body); err != nil { + http.Error(w, "failed to write file", http.StatusInternalServerError) + log.Errorf("Failed to write file %s: %v", filePath, err) + return + } + + log.Infof("Uploaded file %s", filePath) w.WriteHeader(http.StatusOK) } diff --git a/upload-server/server/local_test.go b/upload-server/server/local_test.go index bd8a87809..64b8fd228 100644 --- a/upload-server/server/local_test.go +++ b/upload-server/server/local_test.go @@ -63,3 +63,90 @@ func Test_LocalHandlePutRequest(t *testing.T) { require.NoError(t, err) require.Equal(t, fileContent, createdFileContent) } + +func Test_LocalHandlePutRequest_PathTraversal(t *testing.T) { + mockDir := t.TempDir() + mockURL := "http://localhost:8080" + t.Setenv("SERVER_URL", mockURL) + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + fileContent := []byte("malicious content") + req := httptest.NewRequest(http.MethodPut, putURLPath+"/uploads/%2e%2e%2f%2e%2e%2fetc%2fpasswd", bytes.NewReader(fileContent)) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + + _, err = os.Stat(filepath.Join(mockDir, "..", "..", "etc", "passwd")) + require.True(t, os.IsNotExist(err), "traversal file should not exist") +} + +func Test_LocalHandlePutRequest_DirTraversal(t *testing.T) { + mockDir := t.TempDir() + t.Setenv("SERVER_URL", "http://localhost:8080") + t.Setenv("STORE_DIR", mockDir) + + l := &local{url: "http://localhost:8080", dir: mockDir} + + body := bytes.NewReader([]byte("bad")) + req := httptest.NewRequest(http.MethodPut, putURLPath+"/x/evil.txt", body) + req.SetPathValue("dir", "../../../tmp") + req.SetPathValue("file", "evil.txt") + + rec := httptest.NewRecorder() + l.handlePutRequest(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + + _, err := os.Stat(filepath.Join("/tmp", "evil.txt")) + require.True(t, os.IsNotExist(err), "traversal file should not exist outside store dir") +} + +func Test_LocalHandlePutRequest_DuplicateFile(t *testing.T) { + mockDir := t.TempDir() + t.Setenv("SERVER_URL", "http://localhost:8080") + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPut, putURLPath+"/dir/dup.txt", bytes.NewReader([]byte("first"))) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + req = httptest.NewRequest(http.MethodPut, putURLPath+"/dir/dup.txt", bytes.NewReader([]byte("second"))) + rec = httptest.NewRecorder() + mux.ServeHTTP(rec, req) + require.Equal(t, http.StatusConflict, rec.Code) + + content, err := os.ReadFile(filepath.Join(mockDir, "dir", "dup.txt")) + require.NoError(t, err) + require.Equal(t, []byte("first"), content) +} + +func Test_LocalHandlePutRequest_BodyTooLarge(t *testing.T) { + mockDir := t.TempDir() + t.Setenv("SERVER_URL", "http://localhost:8080") + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + largeBody := make([]byte, maxUploadSize+1) + req := httptest.NewRequest(http.MethodPut, putURLPath+"/dir/big.txt", bytes.NewReader(largeBody)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) + + _, err = os.Stat(filepath.Join(mockDir, "dir", "big.txt")) + require.True(t, os.IsNotExist(err)) +} diff --git a/util/capture/afpacket_linux.go b/util/capture/afpacket_linux.go new file mode 100644 index 000000000..bf59e806a --- /dev/null +++ b/util/capture/afpacket_linux.go @@ -0,0 +1,199 @@ +package capture + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +// htons converts a uint16 from host to network (big-endian) byte order. +func htons(v uint16) uint16 { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], v) + return binary.NativeEndian.Uint16(buf[:]) +} + +// AFPacketCapture reads raw packets from a network interface using an +// AF_PACKET socket. This is the kernel-mode fallback when FilteredDevice is +// not available (kernel WireGuard). Linux only. +// +// It implements device.PacketCapture so it can be set on a Session, but it +// drives its own read loop rather than being called from FilteredDevice. +// Call Start to begin and Stop to end. +type AFPacketCapture struct { + ifaceName string + sess *Session + fd int + mu sync.Mutex + stopped chan struct{} + started atomic.Bool + closed atomic.Bool +} + +// NewAFPacketCapture creates a capture bound to the given interface. +// The session receives packets via Offer. +func NewAFPacketCapture(ifaceName string, sess *Session) *AFPacketCapture { + return &AFPacketCapture{ + ifaceName: ifaceName, + sess: sess, + fd: -1, + stopped: make(chan struct{}), + } +} + +// Start opens the AF_PACKET socket and begins reading packets. +// Packets are fed to the session via Offer. Returns immediately; +// the read loop runs in a goroutine. +func (c *AFPacketCapture) Start() error { + if c.sess == nil { + return errors.New("nil capture session") + } + if !c.started.CompareAndSwap(false, true) { + return errors.New("capture already started") + } + if c.closed.Load() { + c.started.Store(false) + return errors.New("cannot restart stopped capture") + } + + iface, err := net.InterfaceByName(c.ifaceName) + if err != nil { + c.started.Store(false) + return fmt.Errorf("interface %s: %w", c.ifaceName, err) + } + + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL))) + if err != nil { + c.started.Store(false) + return fmt.Errorf("create AF_PACKET socket: %w", err) + } + + addr := &unix.SockaddrLinklayer{ + Protocol: htons(unix.ETH_P_ALL), + Ifindex: iface.Index, + } + if err := unix.Bind(fd, addr); err != nil { + unix.Close(fd) + c.started.Store(false) + return fmt.Errorf("bind to %s: %w", c.ifaceName, err) + } + + c.mu.Lock() + c.fd = fd + c.mu.Unlock() + + go c.readLoop(fd) + return nil +} + +// Stop closes the socket and waits for the read loop to exit. Idempotent. +func (c *AFPacketCapture) Stop() { + if !c.closed.CompareAndSwap(false, true) { + if c.started.Load() { + <-c.stopped + } + return + } + + c.mu.Lock() + fd := c.fd + c.fd = -1 + c.mu.Unlock() + + if fd >= 0 { + unix.Close(fd) + } + + if c.started.Load() { + <-c.stopped + } +} + +func (c *AFPacketCapture) readLoop(fd int) { + defer close(c.stopped) + + buf := make([]byte, 65536) + pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}} + + for { + if c.closed.Load() { + return + } + + ok, err := c.pollOnce(pollFds) + if err != nil { + return + } + if !ok { + continue + } + + c.recvAndOffer(fd, buf) + } +} + +// pollOnce waits for data on the fd. Returns true if data is ready, false for timeout/retry. +// Returns an error to signal the loop should exit. +func (c *AFPacketCapture) pollOnce(pollFds []unix.PollFd) (bool, error) { + n, err := unix.Poll(pollFds, 200) + if err != nil { + if errors.Is(err, unix.EINTR) { + return false, nil + } + if c.closed.Load() { + return false, errors.New("closed") + } + log.Debugf("af_packet poll: %v", err) + return false, err + } + if n == 0 { + return false, nil + } + if pollFds[0].Revents&(unix.POLLERR|unix.POLLHUP|unix.POLLNVAL) != 0 { + return false, errors.New("fd error") + } + return true, nil +} + +func (c *AFPacketCapture) recvAndOffer(fd int, buf []byte) { + nr, from, err := unix.Recvfrom(fd, buf, 0) + if err != nil { + if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) { + return + } + if !c.closed.Load() { + log.Debugf("af_packet recvfrom: %v", err) + } + return + } + if nr < 1 { + return + } + + ver := buf[0] >> 4 + if ver != 4 && ver != 6 { + return + } + + // The kernel sets Pkttype on AF_PACKET sockets: + // PACKET_HOST(0) = addressed to us (inbound) + // PACKET_OUTGOING(4) = sent by us (outbound) + outbound := false + if sa, ok := from.(*unix.SockaddrLinklayer); ok { + outbound = sa.Pkttype == unix.PACKET_OUTGOING + } + c.sess.Offer(buf[:nr], outbound) +} + +// Offer satisfies device.PacketCapture but is unused: the AFPacketCapture +// drives its own read loop. This exists only so the type signature is +// compatible if someone tries to set it as a PacketCapture. +func (c *AFPacketCapture) Offer([]byte, bool) { + // unused: AFPacketCapture drives its own read loop +} diff --git a/util/capture/afpacket_stub.go b/util/capture/afpacket_stub.go new file mode 100644 index 000000000..bde368e88 --- /dev/null +++ b/util/capture/afpacket_stub.go @@ -0,0 +1,26 @@ +//go:build !linux + +package capture + +import "errors" + +// AFPacketCapture is not available on this platform. +type AFPacketCapture struct{} + +// NewAFPacketCapture returns nil on non-Linux platforms. +func NewAFPacketCapture(string, *Session) *AFPacketCapture { return nil } + +// Start returns an error on non-Linux platforms. +func (c *AFPacketCapture) Start() error { + return errors.New("AF_PACKET capture is only supported on Linux") +} + +// Stop is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Stop() { + // no-op on non-Linux platforms +} + +// Offer is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Offer([]byte, bool) { + // no-op on non-Linux platforms +} diff --git a/util/capture/capture.go b/util/capture/capture.go new file mode 100644 index 000000000..0d92a4311 --- /dev/null +++ b/util/capture/capture.go @@ -0,0 +1,59 @@ +// Package capture provides userspace packet capture in pcap format. +// +// It taps decrypted WireGuard packets flowing through the FilteredDevice and +// writes them as pcap (readable by tcpdump, tshark, Wireshark) or as +// human-readable one-line-per-packet text. +package capture + +import "io" + +// Direction indicates whether a packet is entering or leaving the host. +type Direction uint8 + +const ( + // Inbound is a packet arriving from the network (FilteredDevice.Write path). + Inbound Direction = iota + // Outbound is a packet leaving the host (FilteredDevice.Read path). + Outbound +) + +// String returns "IN" or "OUT". +func (d Direction) String() string { + if d == Outbound { + return "OUT" + } + return "IN" +} + +const ( + protoICMP = 1 + protoTCP = 6 + protoUDP = 17 + protoICMPv6 = 58 +) + +// Options configures a capture session. +type Options struct { + // Output receives pcap-formatted data. Nil disables pcap output. + Output io.Writer + // TextOutput receives human-readable packet summaries. Nil disables text output. + TextOutput io.Writer + // Matcher selects which packets to capture. Nil captures all. + // Use ParseFilter("host 10.0.0.1 and tcp") or &Filter{...}. + Matcher Matcher + // Verbose adds seq/ack, TTL, window, total length to text output. + Verbose bool + // ASCII dumps transport payload as printable ASCII after each packet line. + ASCII bool + // SnapLen is the maximum bytes captured per packet. 0 means 65535. + SnapLen uint32 + // BufSize is the internal channel buffer size. 0 means 256. + BufSize int +} + +// Stats reports capture session counters. +type Stats struct { + Packets int64 + Bytes int64 + Dropped int64 +} diff --git a/util/capture/filter.go b/util/capture/filter.go new file mode 100644 index 000000000..d463450b8 --- /dev/null +++ b/util/capture/filter.go @@ -0,0 +1,528 @@ +package capture + +import ( + "encoding/binary" + "fmt" + "net/netip" + "strconv" + "strings" +) + +// Matcher tests whether a raw packet should be captured. +type Matcher interface { + Match(data []byte) bool +} + +// Filter selects packets by flat AND'd criteria. Useful for structured APIs +// (query params, proto fields). Implements Matcher. +type Filter struct { + SrcIP netip.Addr + DstIP netip.Addr + Host netip.Addr + SrcPort uint16 + DstPort uint16 + Port uint16 + Proto uint8 +} + +// IsEmpty returns true if the filter has no criteria set. +func (f *Filter) IsEmpty() bool { + return !f.SrcIP.IsValid() && !f.DstIP.IsValid() && !f.Host.IsValid() && + f.SrcPort == 0 && f.DstPort == 0 && f.Port == 0 && f.Proto == 0 +} + +// Match implements Matcher. All non-zero fields must match (AND). +func (f *Filter) Match(data []byte) bool { + if f.IsEmpty() { + return true + } + info, ok := parsePacketInfo(data) + if !ok { + return false + } + if f.Host.IsValid() && info.srcIP != f.Host && info.dstIP != f.Host { + return false + } + if f.SrcIP.IsValid() && info.srcIP != f.SrcIP { + return false + } + if f.DstIP.IsValid() && info.dstIP != f.DstIP { + return false + } + if f.Proto != 0 && info.proto != f.Proto { + return false + } + if f.Port != 0 && info.srcPort != f.Port && info.dstPort != f.Port { + return false + } + if f.SrcPort != 0 && info.srcPort != f.SrcPort { + return false + } + if f.DstPort != 0 && info.dstPort != f.DstPort { + return false + } + return true +} + +// exprNode evaluates a filter condition against pre-parsed packet info. +type exprNode func(info *packetInfo) bool + +// exprMatcher wraps an expression tree. Parses the packet once, then walks the tree. +type exprMatcher struct { + root exprNode +} + +func (m *exprMatcher) Match(data []byte) bool { + info, ok := parsePacketInfo(data) + if !ok { + return false + } + return m.root(&info) +} + +func nodeAnd(a, b exprNode) exprNode { + return func(info *packetInfo) bool { return a(info) && b(info) } +} + +func nodeOr(a, b exprNode) exprNode { + return func(info *packetInfo) bool { return a(info) || b(info) } +} + +func nodeNot(n exprNode) exprNode { + return func(info *packetInfo) bool { return !n(info) } +} + +func nodeHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.srcIP == addr || info.dstIP == addr } +} + +func nodeSrcHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.srcIP == addr } +} + +func nodeDstHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.dstIP == addr } +} + +func nodePort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.srcPort == port || info.dstPort == port } +} + +func nodeSrcPort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.srcPort == port } +} + +func nodeDstPort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.dstPort == port } +} + +func nodeProto(proto uint8) exprNode { + return func(info *packetInfo) bool { return info.proto == proto } +} + +func nodeFamily(family uint8) exprNode { + return func(info *packetInfo) bool { return info.family == family } +} + +func nodeNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) || prefix.Contains(info.dstIP) } +} + +func nodeSrcNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) } +} + +func nodeDstNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.dstIP) } +} + +// packetInfo holds parsed header fields for filtering and display. +type packetInfo struct { + family uint8 + srcIP netip.Addr + dstIP netip.Addr + proto uint8 + srcPort uint16 + dstPort uint16 + hdrLen int +} + +func parsePacketInfo(data []byte) (packetInfo, bool) { + if len(data) < 1 { + return packetInfo{}, false + } + switch data[0] >> 4 { + case 4: + return parseIPv4Info(data) + case 6: + return parseIPv6Info(data) + default: + return packetInfo{}, false + } +} + +func parseIPv4Info(data []byte) (packetInfo, bool) { + if len(data) < 20 { + return packetInfo{}, false + } + ihl := int(data[0]&0x0f) * 4 + if ihl < 20 || len(data) < ihl { + return packetInfo{}, false + } + info := packetInfo{ + family: 4, + srcIP: netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}), + dstIP: netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}), + proto: data[9], + hdrLen: ihl, + } + if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= ihl+4 { + info.srcPort = binary.BigEndian.Uint16(data[ihl:]) + info.dstPort = binary.BigEndian.Uint16(data[ihl+2:]) + } + return info, true +} + +// parseIPv6Info parses the fixed IPv6 header. It reads the Next Header field +// directly, so packets with extension headers (hop-by-hop, routing, fragment, +// etc.) will report the extension type as the protocol rather than the final +// transport protocol. This is acceptable for a debug capture tool. +func parseIPv6Info(data []byte) (packetInfo, bool) { + if len(data) < 40 { + return packetInfo{}, false + } + var src, dst [16]byte + copy(src[:], data[8:24]) + copy(dst[:], data[24:40]) + info := packetInfo{ + family: 6, + srcIP: netip.AddrFrom16(src), + dstIP: netip.AddrFrom16(dst), + proto: data[6], + hdrLen: 40, + } + if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= 44 { + info.srcPort = binary.BigEndian.Uint16(data[40:]) + info.dstPort = binary.BigEndian.Uint16(data[42:]) + } + return info, true +} + +// ParseFilter parses a BPF-like filter expression and returns a Matcher. +// Returns nil Matcher for an empty expression (match all). +// +// Grammar (mirrors common tcpdump BPF syntax): +// +// orExpr = andExpr ("or" andExpr)* +// andExpr = unary ("and" unary)* +// unary = "not" unary | "(" orExpr ")" | term +// +// term = "host" IP | "src" target | "dst" target +// | "port" NUM | "net" PREFIX +// | "tcp" | "udp" | "icmp" | "icmp6" +// | "ip" | "ip6" | "proto" NUM +// target = "host" IP | "port" NUM | "net" PREFIX | IP +// +// Examples: +// +// host 10.0.0.1 and tcp port 443 +// not port 22 +// (host 10.0.0.1 or host 10.0.0.2) and tcp +// ip6 and icmp6 +// net 10.0.0.0/24 +// src host 10.0.0.1 or dst port 80 +func ParseFilter(expr string) (Matcher, error) { + tokens := tokenize(expr) + if len(tokens) == 0 { + return nil, nil //nolint:nilnil // nil Matcher means "match all" + } + + p := &parser{tokens: tokens} + node, err := p.parseOr() + if err != nil { + return nil, err + } + if p.pos < len(p.tokens) { + return nil, fmt.Errorf("unexpected token %q at position %d", p.tokens[p.pos], p.pos) + } + return &exprMatcher{root: node}, nil +} + +func tokenize(expr string) []string { + expr = strings.TrimSpace(expr) + if expr == "" { + return nil + } + // Split on whitespace but keep parens as separate tokens. + var tokens []string + for _, field := range strings.Fields(expr) { + tokens = append(tokens, splitParens(field)...) + } + return tokens +} + +// splitParens splits "(foo)" into "(", "foo", ")". +func splitParens(s string) []string { + var out []string + for strings.HasPrefix(s, "(") { + out = append(out, "(") + s = s[1:] + } + var trail []string + for strings.HasSuffix(s, ")") { + trail = append(trail, ")") + s = s[:len(s)-1] + } + if s != "" { + out = append(out, s) + } + out = append(out, trail...) + return out +} + +type parser struct { + tokens []string + pos int +} + +func (p *parser) peek() string { + if p.pos >= len(p.tokens) { + return "" + } + return strings.ToLower(p.tokens[p.pos]) +} + +func (p *parser) next() string { + tok := p.peek() + if tok != "" { + p.pos++ + } + return tok +} + +func (p *parser) expect(tok string) error { + got := p.next() + if got != tok { + return fmt.Errorf("expected %q, got %q", tok, got) + } + return nil +} + +func (p *parser) parseOr() (exprNode, error) { + left, err := p.parseAnd() + if err != nil { + return nil, err + } + for p.peek() == "or" { + p.next() + right, err := p.parseAnd() + if err != nil { + return nil, err + } + left = nodeOr(left, right) + } + return left, nil +} + +func (p *parser) parseAnd() (exprNode, error) { + left, err := p.parseUnary() + if err != nil { + return nil, err + } + for { + tok := p.peek() + if tok == "and" { + p.next() + right, err := p.parseUnary() + if err != nil { + return nil, err + } + left = nodeAnd(left, right) + continue + } + // Implicit AND: two atoms without "and" between them. + // Only if the next token starts an atom (not "or", ")", or EOF). + if tok != "" && tok != "or" && tok != ")" { + right, err := p.parseUnary() + if err != nil { + return nil, err + } + left = nodeAnd(left, right) + continue + } + break + } + return left, nil +} + +func (p *parser) parseUnary() (exprNode, error) { + switch p.peek() { + case "not": + p.next() + inner, err := p.parseUnary() + if err != nil { + return nil, err + } + return nodeNot(inner), nil + case "(": + p.next() + inner, err := p.parseOr() + if err != nil { + return nil, err + } + if err := p.expect(")"); err != nil { + return nil, fmt.Errorf("unclosed parenthesis") + } + return inner, nil + default: + return p.parseAtom() + } +} + +func (p *parser) parseAtom() (exprNode, error) { + tok := p.next() + if tok == "" { + return nil, fmt.Errorf("unexpected end of expression") + } + + switch tok { + case "host": + addr, err := p.parseAddr() + if err != nil { + return nil, fmt.Errorf("host: %w", err) + } + return nodeHost(addr), nil + + case "port": + port, err := p.parsePort() + if err != nil { + return nil, fmt.Errorf("port: %w", err) + } + return nodePort(port), nil + + case "net": + prefix, err := p.parsePrefix() + if err != nil { + return nil, fmt.Errorf("net: %w", err) + } + return nodeNet(prefix), nil + + case "src": + return p.parseDirTarget(true) + + case "dst": + return p.parseDirTarget(false) + + case "tcp": + return nodeProto(protoTCP), nil + case "udp": + return nodeProto(protoUDP), nil + case "icmp": + return nodeProto(protoICMP), nil + case "icmp6": + return nodeProto(protoICMPv6), nil + case "ip": + return nodeFamily(4), nil + case "ip6": + return nodeFamily(6), nil + + case "proto": + raw := p.next() + if raw == "" { + return nil, fmt.Errorf("proto: missing number") + } + n, err := strconv.Atoi(raw) + if err != nil || n < 0 || n > 255 { + return nil, fmt.Errorf("proto: invalid number %q", raw) + } + return nodeProto(uint8(n)), nil + + default: + return nil, fmt.Errorf("unknown filter keyword %q", tok) + } +} + +func (p *parser) parseDirTarget(isSrc bool) (exprNode, error) { + tok := p.peek() + switch tok { + case "host": + p.next() + addr, err := p.parseAddr() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcHost(addr), nil + } + return nodeDstHost(addr), nil + + case "port": + p.next() + port, err := p.parsePort() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcPort(port), nil + } + return nodeDstPort(port), nil + + case "net": + p.next() + prefix, err := p.parsePrefix() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcNet(prefix), nil + } + return nodeDstNet(prefix), nil + + default: + // Try as bare IP: "src 10.0.0.1" + addr, err := p.parseAddr() + if err != nil { + return nil, fmt.Errorf("expected host, port, net, or IP after src/dst, got %q", tok) + } + if isSrc { + return nodeSrcHost(addr), nil + } + return nodeDstHost(addr), nil + } +} + +func (p *parser) parseAddr() (netip.Addr, error) { + raw := p.next() + if raw == "" { + return netip.Addr{}, fmt.Errorf("missing IP address") + } + addr, err := netip.ParseAddr(raw) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IP %q", raw) + } + return addr.Unmap(), nil +} + +func (p *parser) parsePort() (uint16, error) { + raw := p.next() + if raw == "" { + return 0, fmt.Errorf("missing port number") + } + n, err := strconv.Atoi(raw) + if err != nil || n < 1 || n > 65535 { + return 0, fmt.Errorf("invalid port %q", raw) + } + return uint16(n), nil +} + +func (p *parser) parsePrefix() (netip.Prefix, error) { + raw := p.next() + if raw == "" { + return netip.Prefix{}, fmt.Errorf("missing network prefix") + } + prefix, err := netip.ParsePrefix(raw) + if err != nil { + return netip.Prefix{}, fmt.Errorf("invalid prefix %q", raw) + } + return prefix, nil +} diff --git a/util/capture/filter_test.go b/util/capture/filter_test.go new file mode 100644 index 000000000..d5fd17566 --- /dev/null +++ b/util/capture/filter_test.go @@ -0,0 +1,263 @@ +package capture + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildIPv4Packet creates a minimal IPv4+TCP/UDP packet for filter testing. +func buildIPv4Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + hdrLen := 20 + pkt := make([]byte, hdrLen+20) + pkt[0] = 0x45 + pkt[9] = proto + + src := srcIP.As4() + dst := dstIP.As4() + copy(pkt[12:16], src[:]) + copy(pkt[16:20], dst[:]) + + pkt[20] = byte(srcPort >> 8) + pkt[21] = byte(srcPort) + pkt[22] = byte(dstPort >> 8) + pkt[23] = byte(dstPort) + + return pkt +} + +// buildIPv6Packet creates a minimal IPv6+TCP/UDP packet for filter testing. +func buildIPv6Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + pkt := make([]byte, 44) // 40 header + 4 ports + pkt[0] = 0x60 // version 6 + pkt[6] = proto // next header + + src := srcIP.As16() + dst := dstIP.As16() + copy(pkt[8:24], src[:]) + copy(pkt[24:40], dst[:]) + + pkt[40] = byte(srcPort >> 8) + pkt[41] = byte(srcPort) + pkt[42] = byte(dstPort >> 8) + pkt[43] = byte(dstPort) + + return pkt +} + +// ---- Filter struct tests ---- + +func TestFilter_Empty(t *testing.T) { + f := Filter{} + assert.True(t, f.IsEmpty()) + assert.True(t, f.Match(buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443))) +} + +func TestFilter_Host(t *testing.T) { + f := Filter{Host: netip.MustParseAddr("10.0.0.1")} + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 1234, 80))) + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.1"), protoTCP, 1234, 80))) + assert.False(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3"), protoTCP, 1234, 80))) +} + +func TestFilter_InvalidPacket(t *testing.T) { + f := Filter{Host: netip.MustParseAddr("10.0.0.1")} + assert.False(t, f.Match(nil)) + assert.False(t, f.Match([]byte{})) + assert.False(t, f.Match([]byte{0x00})) +} + +func TestParsePacketInfo_IPv4(t *testing.T) { + pkt := buildIPv4Packet(t, netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("10.0.0.1"), protoTCP, 54321, 80) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(4), info.family) + assert.Equal(t, netip.MustParseAddr("192.168.1.1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), info.dstIP) + assert.Equal(t, uint8(protoTCP), info.proto) + assert.Equal(t, uint16(54321), info.srcPort) + assert.Equal(t, uint16(80), info.dstPort) +} + +func TestParsePacketInfo_IPv6(t *testing.T) { + pkt := buildIPv6Packet(t, netip.MustParseAddr("fd00::1"), netip.MustParseAddr("fd00::2"), protoUDP, 1234, 53) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(6), info.family) + assert.Equal(t, netip.MustParseAddr("fd00::1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("fd00::2"), info.dstIP) + assert.Equal(t, uint8(protoUDP), info.proto) + assert.Equal(t, uint16(1234), info.srcPort) + assert.Equal(t, uint16(53), info.dstPort) +} + +// ---- ParseFilter expression tests ---- + +func matchV4(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv4Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func matchV6(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv6Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func TestParseFilter_Empty(t *testing.T) { + m, err := ParseFilter("") + require.NoError(t, err) + assert.Nil(t, m, "empty expression should return nil matcher") +} + +func TestParseFilter_Atoms(t *testing.T) { + tests := []struct { + expr string + match bool + }{ + {"tcp", true}, + {"udp", false}, + {"host 10.0.0.1", true}, + {"host 10.0.0.99", false}, + {"port 443", true}, + {"port 80", false}, + {"src host 10.0.0.1", true}, + {"dst host 10.0.0.2", true}, + {"dst host 10.0.0.1", false}, + {"src port 12345", true}, + {"dst port 443", true}, + {"dst port 80", false}, + {"proto 6", true}, + {"proto 17", false}, + } + + pkt := buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 12345, 443) + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + m, err := ParseFilter(tt.expr) + require.NoError(t, err) + assert.Equal(t, tt.match, m.Match(pkt)) + }) + } +} + +func TestParseFilter_And(t *testing.T) { + m, err := ParseFilter("host 10.0.0.1 and tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 55555, 443), "wrong proto") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 80), "wrong port") +} + +func TestParseFilter_ImplicitAnd(t *testing.T) { + // "tcp port 443" = implicit AND between tcp and port 443 + m, err := ParseFilter("tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 1, 443)) +} + +func TestParseFilter_Or(t *testing.T) { + m, err := ParseFilter("port 80 or port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 80)) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080)) +} + +func TestParseFilter_Not(t *testing.T) { + m, err := ParseFilter("not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 22)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 80)) +} + +func TestParseFilter_Parens(t *testing.T) { + m, err := ParseFilter("(port 80 or port 443) and tcp") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoUDP, 1, 443), "wrong proto") + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080), "wrong port") +} + +func TestParseFilter_Family(t *testing.T) { + mV4, err := ParseFilter("ip") + require.NoError(t, err) + assert.True(t, matchV4(t, mV4, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.False(t, matchV6(t, mV4, "fd00::1", "fd00::2", protoTCP, 1, 80)) + + mV6, err := ParseFilter("ip6") + require.NoError(t, err) + assert.False(t, matchV4(t, mV6, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.True(t, matchV6(t, mV6, "fd00::1", "fd00::2", protoTCP, 1, 80)) +} + +func TestParseFilter_Net(t *testing.T) { + m, err := ParseFilter("net 10.0.0.0/24") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "192.168.1.1", protoTCP, 1, 80), "src in net") + assert.True(t, matchV4(t, m, "192.168.1.1", "10.0.0.200", protoTCP, 1, 80), "dst in net") + assert.False(t, matchV4(t, m, "10.0.1.1", "192.168.1.1", protoTCP, 1, 80), "neither in net") +} + +func TestParseFilter_SrcDstNet(t *testing.T) { + m, err := ParseFilter("src net 10.0.0.0/8 and dst net 192.168.0.0/16") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.1.2.3", "192.168.1.1", protoTCP, 1, 80)) + assert.False(t, matchV4(t, m, "192.168.1.1", "10.1.2.3", protoTCP, 1, 80), "reversed") +} + +func TestParseFilter_Complex(t *testing.T) { + // Real-world: capture HTTP(S) traffic to/from specific host, excluding SSH + m, err := ParseFilter("host 10.0.0.1 and (port 80 or port 443) and not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443)) + assert.True(t, matchV4(t, m, "10.0.0.2", "10.0.0.1", protoTCP, 55555, 80)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 443), "port 22 excluded") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") +} + +func TestParseFilter_IPv6Combined(t *testing.T) { + m, err := ParseFilter("ip6 and icmp6") + require.NoError(t, err) + assert.True(t, matchV6(t, m, "fd00::1", "fd00::2", protoICMPv6, 0, 0)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoICMP, 0, 0), "wrong family") + assert.False(t, matchV6(t, m, "fd00::1", "fd00::2", protoTCP, 1, 80), "wrong proto") +} + +func TestParseFilter_CaseInsensitive(t *testing.T) { + m, err := ParseFilter("HOST 10.0.0.1 AND TCP PORT 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) +} + +func TestParseFilter_Errors(t *testing.T) { + bad := []string{ + "badkeyword", + "host", + "port abc", + "port 99999", + "net invalid", + "(", + "(port 80", + "not", + "src", + } + for _, expr := range bad { + t.Run(expr, func(t *testing.T) { + _, err := ParseFilter(expr) + assert.Error(t, err, "should fail for %q", expr) + }) + } +} diff --git a/util/capture/pcap.go b/util/capture/pcap.go new file mode 100644 index 000000000..0a9057045 --- /dev/null +++ b/util/capture/pcap.go @@ -0,0 +1,85 @@ +package capture + +import ( + "encoding/binary" + "io" + "time" +) + +const ( + pcapMagic = 0xa1b2c3d4 + pcapVersionMaj = 2 + pcapVersionMin = 4 + // linkTypeRaw is LINKTYPE_RAW: raw IPv4/IPv6 packets without link-layer header. + linkTypeRaw = 101 + defaultSnapLen = 65535 +) + +// PcapWriter writes packets in pcap format to an underlying writer. +// The global header is written lazily on the first WritePacket call so that +// the writer can be used with unbuffered io.Pipes without deadlocking. +// It is not safe for concurrent use; callers must serialize access. +type PcapWriter struct { + w io.Writer + snapLen uint32 + headerWritten bool +} + +// NewPcapWriter creates a pcap writer. The global header is deferred until the +// first WritePacket call. +func NewPcapWriter(w io.Writer, snapLen uint32) *PcapWriter { + if snapLen == 0 { + snapLen = defaultSnapLen + } + return &PcapWriter{w: w, snapLen: snapLen} +} + +// writeGlobalHeader writes the 24-byte pcap file header. +func (pw *PcapWriter) writeGlobalHeader() error { + var hdr [24]byte + binary.LittleEndian.PutUint32(hdr[0:4], pcapMagic) + binary.LittleEndian.PutUint16(hdr[4:6], pcapVersionMaj) + binary.LittleEndian.PutUint16(hdr[6:8], pcapVersionMin) + binary.LittleEndian.PutUint32(hdr[16:20], pw.snapLen) + binary.LittleEndian.PutUint32(hdr[20:24], linkTypeRaw) + + _, err := pw.w.Write(hdr[:]) + return err +} + +// WriteHeader writes the pcap global header. Safe to call multiple times. +func (pw *PcapWriter) WriteHeader() error { + if pw.headerWritten { + return nil + } + if err := pw.writeGlobalHeader(); err != nil { + return err + } + pw.headerWritten = true + return nil +} + +// WritePacket writes a single packet record, preceded by the global header +// on the first call. +func (pw *PcapWriter) WritePacket(ts time.Time, data []byte) error { + if err := pw.WriteHeader(); err != nil { + return err + } + + origLen := uint32(len(data)) + if origLen > pw.snapLen { + data = data[:pw.snapLen] + } + + var hdr [16]byte + binary.LittleEndian.PutUint32(hdr[0:4], uint32(ts.Unix())) + binary.LittleEndian.PutUint32(hdr[4:8], uint32(ts.Nanosecond()/1000)) + binary.LittleEndian.PutUint32(hdr[8:12], uint32(len(data))) + binary.LittleEndian.PutUint32(hdr[12:16], origLen) + + if _, err := pw.w.Write(hdr[:]); err != nil { + return err + } + _, err := pw.w.Write(data) + return err +} diff --git a/util/capture/pcap_test.go b/util/capture/pcap_test.go new file mode 100644 index 000000000..c3d21ef4a --- /dev/null +++ b/util/capture/pcap_test.go @@ -0,0 +1,68 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPcapWriter_GlobalHeader(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 0) + + // Header is lazy, so write a dummy packet to trigger it. + err := pw.WritePacket(time.Now(), []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 64, 1, 0, 0, 10, 0, 0, 1, 10, 0, 0, 2}) + require.NoError(t, err) + + data := buf.Bytes() + require.GreaterOrEqual(t, len(data), 24, "should contain global header") + + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]), "magic number") + assert.Equal(t, uint16(pcapVersionMaj), binary.LittleEndian.Uint16(data[4:6]), "version major") + assert.Equal(t, uint16(pcapVersionMin), binary.LittleEndian.Uint16(data[6:8]), "version minor") + assert.Equal(t, uint32(defaultSnapLen), binary.LittleEndian.Uint32(data[16:20]), "snap length") + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]), "link type") +} + +func TestPcapWriter_WritePacket(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 100) + + ts := time.Date(2025, 6, 15, 12, 30, 45, 123456000, time.UTC) + payload := make([]byte, 50) + for i := range payload { + payload[i] = byte(i) + } + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] // skip global header + require.Len(t, data, 16+50, "packet header + payload") + + assert.Equal(t, uint32(ts.Unix()), binary.LittleEndian.Uint32(data[0:4]), "timestamp seconds") + assert.Equal(t, uint32(123456), binary.LittleEndian.Uint32(data[4:8]), "timestamp microseconds") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[8:12]), "included length") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length") + assert.Equal(t, payload, data[16:], "packet data") +} + +func TestPcapWriter_SnapLen(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 10) + + ts := time.Now() + payload := make([]byte, 50) + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] + assert.Equal(t, uint32(10), binary.LittleEndian.Uint32(data[8:12]), "included length should be truncated") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length preserved") + assert.Len(t, data[16:], 10, "only snap_len bytes written") +} diff --git a/util/capture/session.go b/util/capture/session.go new file mode 100644 index 000000000..09806e10c --- /dev/null +++ b/util/capture/session.go @@ -0,0 +1,213 @@ +package capture + +import ( + "fmt" + "sync" + "sync/atomic" + "time" +) + +const defaultBufSize = 256 + +type packetEntry struct { + ts time.Time + data []byte + dir Direction +} + +// Session manages an active packet capture. Packets are offered via Offer, +// buffered in a channel, and written to configured sinks by a background +// goroutine. This keeps the hot path (FilteredDevice.Read/Write) non-blocking. +// +// The caller must call Stop when done to flush remaining packets and release +// resources. +type Session struct { + pcapW *PcapWriter + textW *TextWriter + matcher Matcher + snapLen uint32 + flushFn func() + + ch chan packetEntry + done chan struct{} + stopped chan struct{} + + closeOnce sync.Once + closed atomic.Bool + packets atomic.Int64 + bytes atomic.Int64 + dropped atomic.Int64 + started time.Time +} + +// NewSession creates and starts a capture session. At least one of +// Options.Output or Options.TextOutput must be non-nil. +func NewSession(opts Options) (*Session, error) { + if opts.Output == nil && opts.TextOutput == nil { + return nil, fmt.Errorf("at least one output sink required") + } + + snapLen := opts.SnapLen + if snapLen == 0 { + snapLen = defaultSnapLen + } + + bufSize := opts.BufSize + if bufSize <= 0 { + bufSize = defaultBufSize + } + + s := &Session{ + matcher: opts.Matcher, + snapLen: snapLen, + ch: make(chan packetEntry, bufSize), + done: make(chan struct{}), + stopped: make(chan struct{}), + started: time.Now(), + } + + if opts.Output != nil { + s.pcapW = NewPcapWriter(opts.Output, snapLen) + } + if opts.TextOutput != nil { + s.textW = NewTextWriter(opts.TextOutput, opts.Verbose, opts.ASCII) + } + + s.flushFn = buildFlushFn(opts.Output, opts.TextOutput) + + go s.run() + return s, nil +} + +// Offer submits a packet for capture. It returns immediately and never blocks +// the caller. If the internal buffer is full the packet is dropped silently. +// +// outbound should be true for packets leaving the host (FilteredDevice.Read +// path) and false for packets arriving (FilteredDevice.Write path). +// +// Offer satisfies the device.PacketCapture interface. +func (s *Session) Offer(data []byte, outbound bool) { + if s.closed.Load() { + return + } + + if s.matcher != nil && !s.matcher.Match(data) { + return + } + + captureLen := len(data) + if s.snapLen > 0 && uint32(captureLen) > s.snapLen { + captureLen = int(s.snapLen) + } + + copied := make([]byte, captureLen) + copy(copied, data) + + dir := Inbound + if outbound { + dir = Outbound + } + + select { + case s.ch <- packetEntry{ts: time.Now(), data: copied, dir: dir}: + s.packets.Add(1) + s.bytes.Add(int64(len(data))) + default: + s.dropped.Add(1) + } +} + +// Stop signals the session to stop accepting packets, drains any buffered +// packets to the sinks, and waits for the writer goroutine to exit. +// It is safe to call multiple times. +func (s *Session) Stop() { + s.closeOnce.Do(func() { + s.closed.Store(true) + close(s.done) + }) + <-s.stopped +} + +// Done returns a channel that is closed when the session's writer goroutine +// has fully exited and all buffered packets have been flushed. +func (s *Session) Done() <-chan struct{} { + return s.stopped +} + +// Stats returns current capture counters. +func (s *Session) Stats() Stats { + return Stats{ + Packets: s.packets.Load(), + Bytes: s.bytes.Load(), + Dropped: s.dropped.Load(), + } +} + +func (s *Session) run() { + defer close(s.stopped) + + for { + select { + case pkt := <-s.ch: + s.write(pkt) + case <-s.done: + s.drain() + return + } + } +} + +func (s *Session) drain() { + for { + select { + case pkt := <-s.ch: + s.write(pkt) + default: + return + } + } +} + +func (s *Session) write(pkt packetEntry) { + if s.pcapW != nil { + // Best-effort: if the writer fails (broken pipe etc.), discard silently. + _ = s.pcapW.WritePacket(pkt.ts, pkt.data) + } + if s.textW != nil { + _ = s.textW.WritePacket(pkt.ts, pkt.data, pkt.dir) + } + s.flushFn() +} + +// buildFlushFn returns a function that flushes all writers that support it. +// This covers http.Flusher and similar streaming writers. +func buildFlushFn(writers ...any) func() { + type flusher interface { + Flush() + } + + var fns []func() + for _, w := range writers { + if w == nil { + continue + } + if f, ok := w.(flusher); ok { + fns = append(fns, f.Flush) + } + } + + switch len(fns) { + case 0: + return func() { + // no writers to flush + } + case 1: + return fns[0] + default: + return func() { + for _, fn := range fns { + fn() + } + } + } +} diff --git a/util/capture/session_test.go b/util/capture/session_test.go new file mode 100644 index 000000000..ab27686c6 --- /dev/null +++ b/util/capture/session_test.go @@ -0,0 +1,144 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSession_PcapOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + + sess.Offer(pkt, true) + sess.Stop() + + data := buf.Bytes() + require.Greater(t, len(data), 24, "should have global header + at least one packet") + + // Verify global header + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4])) + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24])) + + // Verify packet record + pktData := data[24:] + inclLen := binary.LittleEndian.Uint32(pktData[8:12]) + assert.Equal(t, uint32(len(pkt)), inclLen) + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets) + assert.Equal(t, int64(len(pkt)), stats.Bytes) + assert.Equal(t, int64(0), stats.Dropped) +} + +func TestSession_TextOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + TextOutput: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + + sess.Offer(pkt, false) + sess.Stop() + + output := buf.String() + assert.Contains(t, output, "TCP") + assert.Contains(t, output, "10.0.0.1") + assert.Contains(t, output, "10.0.0.2") + assert.Contains(t, output, "443") + assert.Contains(t, output, "[IN TCP]") +} + +func TestSession_Filter(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + Matcher: &Filter{Port: 443}, + }) + require.NoError(t, err) + + pktMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + pktNoMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 80) + + sess.Offer(pktMatch, true) + sess.Offer(pktNoMatch, true) + sess.Stop() + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets, "only matching packet should be captured") +} + +func TestSession_StopIdempotent(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + sess.Stop() + sess.Stop() // should not panic or deadlock +} + +func TestSession_OfferAfterStop(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + sess.Stop() + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + sess.Offer(pkt, true) // should not panic + + assert.Equal(t, int64(0), sess.Stats().Packets) +} + +func TestSession_Done(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + select { + case <-sess.Done(): + t.Fatal("Done should not be closed before Stop") + default: + } + + sess.Stop() + + select { + case <-sess.Done(): + case <-time.After(time.Second): + t.Fatal("Done should be closed after Stop") + } +} + +func TestSession_RequiresOutput(t *testing.T) { + _, err := NewSession(Options{}) + assert.Error(t, err) +} diff --git a/util/capture/text.go b/util/capture/text.go new file mode 100644 index 000000000..b44bd0cad --- /dev/null +++ b/util/capture/text.go @@ -0,0 +1,638 @@ +package capture + +import ( + "encoding/binary" + "fmt" + "io" + "net/netip" + "strings" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// TextWriter writes human-readable one-line-per-packet summaries. +// It is not safe for concurrent use; callers must serialize access. +type TextWriter struct { + w io.Writer + verbose bool + ascii bool + flows map[dirKey]uint32 +} + +type dirKey struct { + src netip.AddrPort + dst netip.AddrPort +} + +// NewTextWriter creates a text formatter that writes to w. +func NewTextWriter(w io.Writer, verbose, ascii bool) *TextWriter { + return &TextWriter{ + w: w, + verbose: verbose, + ascii: ascii, + flows: make(map[dirKey]uint32), + } +} + +// tag formats the fixed-width "[DIR PROTO]" prefix with right-aligned protocol. +func tag(dir Direction, proto string) string { + return fmt.Sprintf("[%-3s %4s]", dir, proto) +} + +// WritePacket formats and writes a single packet line. +func (tw *TextWriter) WritePacket(ts time.Time, data []byte, dir Direction) error { + ts = ts.Local() + info, ok := parsePacketInfo(data) + if !ok { + _, err := fmt.Fprintf(tw.w, "%s [%-3s ?] ??? len=%d\n", + ts.Format("15:04:05.000000"), dir, len(data)) + return err + } + + timeStr := ts.Format("15:04:05.000000") + + var err error + switch info.proto { + case protoTCP: + err = tw.writeTCP(timeStr, dir, &info, data) + case protoUDP: + err = tw.writeUDP(timeStr, dir, &info, data) + case protoICMP: + err = tw.writeICMPv4(timeStr, dir, &info, data) + case protoICMPv6: + err = tw.writeICMPv6(timeStr, dir, &info, data) + default: + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err = fmt.Fprintf(tw.w, "%s %s %s > %s length %d%s\n", + timeStr, tag(dir, fmt.Sprintf("P%d", info.proto)), + info.srcIP, info.dstIP, len(data)-info.hdrLen, verbose) + } + return err +} + +func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, data []byte) error { + tcp := &layers.TCP{} + if err := tcp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "TCP", info, data) + } + + flags := tcpFlagsStr(tcp) + plen := len(tcp.Payload) + + // Protocol annotation + var annotation string + if plen > 0 { + annotation = annotatePayload(tcp.Payload) + } + + if !tw.verbose { + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s] length %d%s\n", + timeStr, tag(dir, "TCP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + flags, plen, annotation) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(tcp.Payload) + } + return nil + } + + relSeq, relAck := tw.relativeSeqAck(info, tcp.Seq, tcp.Ack) + + var seqStr string + if plen > 0 { + seqStr = fmt.Sprintf(", seq %d:%d", relSeq, relSeq+uint32(plen)) + } else { + seqStr = fmt.Sprintf(", seq %d", relSeq) + } + + var ackStr string + if tcp.ACK { + ackStr = fmt.Sprintf(", ack %d", relAck) + } + + var opts string + if s := formatTCPOptions(tcp.Options); s != "" { + opts = ", options [" + s + "]" + } + + verbose := tw.verboseIP(data, info.family) + + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s]%s%s, win %d%s, length %d%s%s\n", + timeStr, tag(dir, "TCP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + flags, seqStr, ackStr, tcp.Window, opts, plen, annotation, verbose) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(tcp.Payload) + } + return nil +} + +func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, data []byte) error { + udp := &layers.UDP{} + if err := udp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "UDP", info, data) + } + + plen := len(udp.Payload) + + // DNS replaces the entire line format + if plen > 0 && isDNSPort(info.srcPort, info.dstPort) { + if s := formatDNSPayload(udp.Payload); s != "" { + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d %s%s\n", + timeStr, tag(dir, "UDP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + s, verbose) + return err + } + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d%s\n", + timeStr, tag(dir, "UDP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + plen, verbose) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(udp.Payload) + } + return nil +} + +func (tw *TextWriter) writeICMPv4(timeStr string, dir Direction, info *packetInfo, data []byte) error { + icmp := &layers.ICMPv4{} + if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "ICMP", info, data) + } + + var detail string + if icmp.TypeCode.Type() == layers.ICMPv4TypeEchoRequest || icmp.TypeCode.Type() == layers.ICMPv4TypeEchoReply { + detail = fmt.Sprintf("%s, id %d, seq %d", icmp.TypeCode.String(), icmp.Id, icmp.Seq) + } else { + detail = icmp.TypeCode.String() + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n", + timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, detail, len(data)-info.hdrLen, verbose) + return err +} + +func (tw *TextWriter) writeICMPv6(timeStr string, dir Direction, info *packetInfo, data []byte) error { + icmp := &layers.ICMPv6{} + if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "ICMP", info, data) + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n", + timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, icmp.TypeCode.String(), len(data)-info.hdrLen, verbose) + return err +} + +func (tw *TextWriter) writeFallback(timeStr string, dir Direction, proto string, info *packetInfo, data []byte) error { + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d\n", + timeStr, tag(dir, proto), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + len(data)-info.hdrLen) + return err +} + +func (tw *TextWriter) verboseIP(data []byte, family uint8) string { + return fmt.Sprintf(", ttl %d, id %d, iplen %d", + ipTTL(data, family), ipID(data, family), len(data)) +} + +// relativeSeqAck returns seq/ack relative to the first seen value per direction. +func (tw *TextWriter) relativeSeqAck(info *packetInfo, seq, ack uint32) (relSeq, relAck uint32) { + fwd := dirKey{ + src: netip.AddrPortFrom(info.srcIP, info.srcPort), + dst: netip.AddrPortFrom(info.dstIP, info.dstPort), + } + rev := dirKey{ + src: netip.AddrPortFrom(info.dstIP, info.dstPort), + dst: netip.AddrPortFrom(info.srcIP, info.srcPort), + } + + if isn, ok := tw.flows[fwd]; ok { + relSeq = seq - isn + } else { + tw.flows[fwd] = seq + } + + if isn, ok := tw.flows[rev]; ok { + relAck = ack - isn + } else { + relAck = ack + } + + return relSeq, relAck +} + +// writeASCII prints payload bytes as printable ASCII. +func (tw *TextWriter) writeASCII(payload []byte) error { + if len(payload) == 0 { + return nil + } + buf := make([]byte, len(payload)) + for i, b := range payload { + switch { + case b >= 0x20 && b < 0x7f: + buf[i] = b + case b == '\n' || b == '\r' || b == '\t': + buf[i] = b + default: + buf[i] = '.' + } + } + _, err := fmt.Fprintf(tw.w, "%s\n", buf) + return err +} + +// --- TCP helpers --- + +func ipTTL(data []byte, family uint8) uint8 { + if family == 4 && len(data) > 8 { + return data[8] + } + if family == 6 && len(data) > 7 { + return data[7] + } + return 0 +} + +func ipID(data []byte, family uint8) uint16 { + if family == 4 && len(data) >= 6 { + return binary.BigEndian.Uint16(data[4:6]) + } + return 0 +} + +func tcpFlagsStr(tcp *layers.TCP) string { + var buf [6]byte + n := 0 + if tcp.SYN { + buf[n] = 'S' + n++ + } + if tcp.FIN { + buf[n] = 'F' + n++ + } + if tcp.RST { + buf[n] = 'R' + n++ + } + if tcp.PSH { + buf[n] = 'P' + n++ + } + if tcp.ACK { + buf[n] = '.' + n++ + } + if tcp.URG { + buf[n] = 'U' + n++ + } + if n == 0 { + return "none" + } + return string(buf[:n]) +} + +func formatTCPOptions(opts []layers.TCPOption) string { + var parts []string + for _, opt := range opts { + switch opt.OptionType { + case layers.TCPOptionKindEndList: + return strings.Join(parts, ",") + case layers.TCPOptionKindNop: + parts = append(parts, "nop") + case layers.TCPOptionKindMSS: + if len(opt.OptionData) == 2 { + parts = append(parts, fmt.Sprintf("mss %d", binary.BigEndian.Uint16(opt.OptionData))) + } + case layers.TCPOptionKindWindowScale: + if len(opt.OptionData) == 1 { + parts = append(parts, fmt.Sprintf("wscale %d", opt.OptionData[0])) + } + case layers.TCPOptionKindSACKPermitted: + parts = append(parts, "sackOK") + case layers.TCPOptionKindSACK: + blocks := len(opt.OptionData) / 8 + parts = append(parts, fmt.Sprintf("sack %d", blocks)) + case layers.TCPOptionKindTimestamps: + if len(opt.OptionData) == 8 { + tsval := binary.BigEndian.Uint32(opt.OptionData[0:4]) + tsecr := binary.BigEndian.Uint32(opt.OptionData[4:8]) + parts = append(parts, fmt.Sprintf("TS val %d ecr %d", tsval, tsecr)) + } + } + } + return strings.Join(parts, ",") +} + +// --- Protocol annotation --- + +// annotatePayload returns a protocol annotation string for known application protocols. +func annotatePayload(payload []byte) string { + if len(payload) < 4 { + return "" + } + + s := string(payload) + + // SSH banner: "SSH-2.0-OpenSSH_9.6\r\n" + if strings.HasPrefix(s, "SSH-") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 256 { + return ": " + s[:end] + } + } + + // TLS records + if ann := annotateTLS(payload); ann != "" { + return ": " + ann + } + + // HTTP request or response + for _, method := range [...]string{"GET ", "POST ", "PUT ", "DELETE ", "HEAD ", "PATCH ", "OPTIONS ", "CONNECT "} { + if strings.HasPrefix(s, method) { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + } + if strings.HasPrefix(s, "HTTP/") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + + return "" +} + +// annotateTLS returns a description for TLS handshake and alert records. +func annotateTLS(data []byte) string { + if len(data) < 6 { + return "" + } + + switch data[0] { + case 0x16: + return annotateTLSHandshake(data) + case 0x15: + return annotateTLSAlert(data) + } + return "" +} + +func annotateTLSHandshake(data []byte) string { + if len(data) < 10 { + return "" + } + switch data[5] { + case 0x01: + if sni := extractSNI(data); sni != "" { + return "TLS ClientHello SNI=" + sni + } + return "TLS ClientHello" + case 0x02: + return "TLS ServerHello" + } + return "" +} + +func annotateTLSAlert(data []byte) string { + if len(data) < 7 { + return "" + } + severity := "warning" + if data[5] == 2 { + severity = "fatal" + } + return fmt.Sprintf("TLS Alert %s %s", severity, tlsAlertDesc(data[6])) +} + +func tlsAlertDesc(code byte) string { + switch code { + case 0: + return "close_notify" + case 10: + return "unexpected_message" + case 40: + return "handshake_failure" + case 42: + return "bad_certificate" + case 43: + return "unsupported_certificate" + case 44: + return "certificate_revoked" + case 45: + return "certificate_expired" + case 48: + return "unknown_ca" + case 49: + return "access_denied" + case 50: + return "decode_error" + case 70: + return "protocol_version" + case 80: + return "internal_error" + case 86: + return "inappropriate_fallback" + case 90: + return "user_canceled" + case 112: + return "unrecognized_name" + default: + return fmt.Sprintf("alert(%d)", code) + } +} + +// extractSNI parses a TLS ClientHello and returns the SNI server name. +func extractSNI(data []byte) string { + if len(data) < 6 || data[0] != 0x16 { + return "" + } + recordLen := int(binary.BigEndian.Uint16(data[3:5])) + handshake := data[5:] + if len(handshake) > recordLen { + handshake = handshake[:recordLen] + } + + if len(handshake) < 4 || handshake[0] != 0x01 { + return "" + } + hsLen := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3]) + body := handshake[4:] + if len(body) > hsLen { + body = body[:hsLen] + } + + extPos := clientHelloExtensionsOffset(body) + if extPos < 0 { + return "" + } + return findSNIExtension(body, extPos) +} + +// clientHelloExtensionsOffset returns the byte offset where extensions begin +// within the ClientHello body, or -1 if the body is too short. +func clientHelloExtensionsOffset(body []byte) int { + if len(body) < 38 { + return -1 + } + pos := 34 + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // session ID + + if pos+2 > len(body) { + return -1 + } + pos += 2 + int(binary.BigEndian.Uint16(body[pos:pos+2])) // cipher suites + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // compression methods + + if pos+2 > len(body) { + return -1 + } + return pos +} + +func findSNIExtension(body []byte, pos int) string { + extLen := int(binary.BigEndian.Uint16(body[pos : pos+2])) + pos += 2 + extEnd := pos + extLen + if extEnd > len(body) { + extEnd = len(body) + } + + for pos+4 <= extEnd { + extType := binary.BigEndian.Uint16(body[pos : pos+2]) + eLen := int(binary.BigEndian.Uint16(body[pos+2 : pos+4])) + pos += 4 + if pos+eLen > extEnd { + break + } + if extType == 0 && eLen >= 5 { + nameLen := int(binary.BigEndian.Uint16(body[pos+3 : pos+5])) + if pos+5+nameLen <= extEnd { + return string(body[pos+5 : pos+5+nameLen]) + } + } + pos += eLen + } + return "" +} + +func isDNSPort(src, dst uint16) bool { + return src == 53 || dst == 53 || src == 5353 || dst == 5353 +} + +// formatDNSPayload parses DNS and returns a tcpdump-style summary. +func formatDNSPayload(payload []byte) string { + d := &layers.DNS{} + if err := d.DecodeFromBytes(payload, gopacket.NilDecodeFeedback); err != nil { + return "" + } + + rd := "" + if d.RD { + rd = "+" + } + + if !d.QR { + return formatDNSQuery(d, rd, len(payload)) + } + return formatDNSResponse(d, rd, len(payload)) +} + +func formatDNSQuery(d *layers.DNS, rd string, plen int) string { + if len(d.Questions) == 0 { + return fmt.Sprintf("%04x%s (%d)", d.ID, rd, plen) + } + q := d.Questions[0] + return fmt.Sprintf("%04x%s %s? %s. (%d)", d.ID, rd, q.Type, q.Name, plen) +} + +func formatDNSResponse(d *layers.DNS, rd string, plen int) string { + anCount := d.ANCount + nsCount := d.NSCount + arCount := d.ARCount + + if d.ResponseCode != layers.DNSResponseCodeNoErr { + return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen) + } + + if anCount > 0 && len(d.Answers) > 0 { + rr := d.Answers[0] + if rdata := shortRData(&rr); rdata != "" { + return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen) + } + } + + return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen) +} + +func shortRData(rr *layers.DNSResourceRecord) string { + switch rr.Type { + case layers.DNSTypeA, layers.DNSTypeAAAA: + if rr.IP != nil { + return rr.IP.String() + } + case layers.DNSTypeCNAME: + if len(rr.CNAME) > 0 { + return string(rr.CNAME) + "." + } + case layers.DNSTypePTR: + if len(rr.PTR) > 0 { + return string(rr.PTR) + "." + } + case layers.DNSTypeNS: + if len(rr.NS) > 0 { + return string(rr.NS) + "." + } + case layers.DNSTypeMX: + return fmt.Sprintf("%d %s.", rr.MX.Preference, rr.MX.Name) + case layers.DNSTypeTXT: + if len(rr.TXTs) > 0 { + return fmt.Sprintf("%q", string(rr.TXTs[0])) + } + case layers.DNSTypeSRV: + return fmt.Sprintf("%d %d %d %s.", rr.SRV.Priority, rr.SRV.Weight, rr.SRV.Port, rr.SRV.Name) + } + return "" +}