diff --git a/.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml b/.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml new file mode 100644 index 000000000..f57a62107 --- /dev/null +++ b/.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml @@ -0,0 +1,130 @@ +body: + - type: markdown + attributes: + value: | + ## Ideas & Feature Requests + + Use this category for feature requests, enhancements, integrations, and product ideas. + + NetBird uses community traction in discussions — upvotes, replies, affected users, and use-case detail — as an input when deciding what should become a maintainer-curated issue or roadmap item. A clear problem statement is more useful than a solution-only request. + + Please search first and add your use case to an existing discussion when one already exists. + + - type: checkboxes + id: preflight + attributes: + label: Before posting + options: + - label: I searched existing discussions and issues for similar requests. + required: true + - label: I checked the documentation to confirm this is not already supported. + required: true + - label: This is a product idea or enhancement request, not a support question. + required: true + - label: I removed or anonymized sensitive details from examples and screenshots. + required: true + + - type: dropdown + id: area + attributes: + label: Product area + description: Select every area this request touches. 
+ multiple: true + options: + - Client / Agent + - CLI + - Desktop UI + - Mobile app + - Dashboard / Admin UI + - Management service / API + - Signal service + - Relay + - DNS + - Routes / Exit nodes + - NetBird SSH + - Access control policies + - Posture checks + - Identity provider / SSO + - Self-hosting / Deployment + - Kubernetes / Operator + - Terraform / Automation + - Documentation + - Other / not sure + validations: + required: true + + - type: textarea + id: problem + attributes: + label: Problem or use case + description: What are you trying to accomplish, and what is difficult or impossible today? + placeholder: | + As a ... + I want to ... + Because ... + validations: + required: true + + - type: textarea + id: proposal + attributes: + label: Proposed solution + description: Describe the behavior, workflow, API, UI, or integration you would like to see. + validations: + required: true + + - type: textarea + id: alternatives + attributes: + label: Alternatives or workarounds considered + description: What have you tried today? Why is the current workaround not enough? + + - type: textarea + id: impact + attributes: + label: Community impact and priority + description: Help us understand who benefits and how urgent this is. + placeholder: | + - Number of users/teams/peers affected: + - Deployment type: Cloud / self-hosted / both + - Frequency: daily / weekly / occasional + - Blocking production adoption? yes/no + - Related comments, discussions, or customer requests: + validations: + required: true + + - type: textarea + id: examples + attributes: + label: Examples from other tools or products + description: If another tool solves this well, link or describe the behavior. + + - type: textarea + id: security + attributes: + label: Security, privacy, and compatibility considerations + description: Note any access-control, audit, data retention, network, platform, or backward-compatibility concerns. 
+ + - type: textarea + id: implementation + attributes: + label: Implementation ideas + description: Optional. If you are familiar with the codebase or API, share possible implementation notes. + + - type: dropdown + id: contribution + attributes: + label: Are you willing to help? + options: + - Yes, I can submit a PR if the approach is accepted. + - Yes, I can test or validate a proposed implementation. + - Yes, I can provide more use-case details. + - Not at this time. + validations: + required: true + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add screenshots, diagrams, links, or anything else that helps explain the request. diff --git a/.github/DISCUSSION_TEMPLATE/issue-triage.yml b/.github/DISCUSSION_TEMPLATE/issue-triage.yml new file mode 100644 index 000000000..b13ec066e --- /dev/null +++ b/.github/DISCUSSION_TEMPLATE/issue-triage.yml @@ -0,0 +1,237 @@ +body: + - type: markdown + attributes: + value: | + ## Issue Triage + + Use this category for reproducible bugs and regressions in NetBird. + + The more context you include, the faster we can validate and act on your report. If you're not sure whether something is a bug, **Q&A / Support** is a good starting point — we can always move the conversation here once we've confirmed it's a product issue. + + Intermittent issues are useful too. Include the trigger, frequency, timing, and any logs or debug evidence you have, and we'll work from there. + + Please don't include secrets, tokens, private keys, internal hostnames, or public IPs. Security vulnerabilities should be reported through the repository security policy rather than a public discussion. + + - type: checkboxes + id: preflight + attributes: + label: Before posting + options: + - label: I searched existing discussions and issues, including closed ones, and checked the relevant docs. + required: true + - label: I believe this is a product bug rather than a configuration or setup question. 
+ required: true + - label: I can reproduce this issue, or for intermittent issues I've included trigger, frequency, and timing details below. + required: true + - label: I removed or anonymized sensitive data from logs, screenshots, and configuration. + required: true + + - type: dropdown + id: area + attributes: + label: Affected area + description: Select every area this report touches. + multiple: true + options: + - Client / Agent + - Reverse Proxy + - CLI + - Desktop UI + - Mobile app + - Peer connectivity + - DNS + - Routes / Exit nodes + - NetBird SSH + - Relay / Signal / NAT traversal + - Login / Authentication / IdP + - Dashboard / Admin UI + - Management service / API + - Access control policies / Posture checks + - Self-hosting / Deployment + - Kubernetes / Operator + - Documentation + - Other / not sure + validations: + required: true + + - type: dropdown + id: deployment + attributes: + label: Deployment type + options: + - NetBird Cloud + - Self-hosted - quickstart script + - Self-hosted - advanced/custom deployment + - Local development build + - Not sure / environment I do not fully control + validations: + required: true + + - type: dropdown + id: platform + attributes: + label: Operating system or environment + description: Select every environment involved in the reproduction. + multiple: true + options: + - Linux + - macOS + - Windows + - Android + - iOS + - FreeBSD + - OpenWRT + - Docker + - Kubernetes + - Synology + - Browser + - Other / not sure + validations: + required: true + + - type: textarea + id: version + attributes: + label: NetBird version and upgrade status + description: Run `netbird version` where applicable. For self-hosted deployments, include management, signal, relay, and dashboard versions if available. If you cannot test on a current/supported version, explain why. 
+ placeholder: | + Example: + - Client: 0.30.2 + - Management: 0.30.2 + - Signal: 0.30.2 + - Relay: 0.30.2 + - Dashboard: 0.30.2 + - Upgrade status: reproduced on current version / cannot upgrade because ... + validations: + required: true + + - type: dropdown + id: regression + attributes: + label: Did this work before? + options: + - Yes, this worked before + - No, this never worked + - Not sure + validations: + required: true + + - type: textarea + id: regression-details + attributes: + label: Regression details + description: If this worked before, include the last known working version, first known broken version, and any recent upgrade, configuration, network, or IdP changes. + placeholder: | + - Last known working version: + - First known broken version: + - Recent changes: + + - type: textarea + id: summary + attributes: + label: Summary + description: Briefly describe the reproducible bug. + placeholder: What is broken? + validations: + required: true + + - type: textarea + id: current-behavior + attributes: + label: Current behavior + description: What happens now? Include exact errors, timeouts, UI messages, or failed commands when possible. + validations: + required: true + + - type: textarea + id: expected-behavior + attributes: + label: Expected behavior + description: What did you expect to happen instead? + validations: + required: true + + - type: textarea + id: reproduction + attributes: + label: Steps to reproduce + description: Provide the smallest set of steps that reliably reproduces the bug. If the issue is intermittent, include the trigger, frequency, timing, and relevant timestamps. + placeholder: | + 1. Configure ... + 2. Run ... + 3. Observe ... + + For intermittent issues: + - Trigger: + - Frequency: + - Timing/timestamps: + validations: + required: true + + - type: textarea + id: environment + attributes: + label: Environment and topology + description: Include the relevant topology and software involved in the reproduction. 
For UI/docs-only reports, write `N/A` if this does not apply. Use `None`, `Unknown`, or `N/A` where appropriate. + placeholder: | + - Peer A: + - Peer B: + - Same LAN or different networks: + - NAT/CGNAT/corporate firewall/mobile network: + - Other VPN software: + - Firewall, DNS, or endpoint security software: + - Routes, DNS, policies, posture checks, or SSH rules involved: + - IdP, reverse proxy, or browser involved: + validations: + required: true + + - type: textarea + id: self-hosted-details + attributes: + label: Self-hosted details, if available + description: Optional. If you use self-hosting and have access to these details, include them. If you do not administer the environment, provide what you know and say what you cannot access. + placeholder: | + - Deployment method: quickstart / Docker Compose / Helm / operator / custom + - Management/signal/relay/dashboard versions: + - Reverse proxy: + - IdP/provider: + - STUN/TURN/coturn/relay details: + - Relevant component logs: + + - type: textarea + id: logs + attributes: + label: Logs, status output, or debug evidence + description: | + For client, connectivity, DNS, route, relay/signal, or self-hosted reports, logs are essential — please include anonymized output from `netbird status -dA`, or a debug bundle via `netbird debug for 1m -AS -U`. Debug bundles are automatically deleted after 30 days. + + For UI, dashboard, or documentation reports, leave the pre-filled `N/A`. + value: "N/A" + render: shell + validations: + required: true + + - type: textarea + id: related-reports + attributes: + label: Related issues or discussions + description: Optional. Link similar reports you found while searching, if any. + placeholder: | + - Related issue/discussion: + - Why this may be the same or different: + + - type: textarea + id: impact + attributes: + label: Impact + description: Optional. Help us understand priority. How many users, peers, environments, or workflows are affected? Is there a workaround? 
+ placeholder: | + - Affected users/peers: + - Business or production impact: + - Workaround available: + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add links to related discussions, issues, docs, screenshots, recordings, or anything else that may help validation. diff --git a/.github/DISCUSSION_TEMPLATE/q-a-support.yml b/.github/DISCUSSION_TEMPLATE/q-a-support.yml new file mode 100644 index 000000000..725f8737c --- /dev/null +++ b/.github/DISCUSSION_TEMPLATE/q-a-support.yml @@ -0,0 +1,146 @@ +body: + - type: markdown + attributes: + value: | + ## Q&A / Support + + Use this category for questions about configuration, setup, self-hosted deployments, troubleshooting, and general NetBird usage. + + This is community support and does not provide an SLA. For NetBird Cloud support, use the official support channel linked from the issue creation page. Please do not post secrets, tokens, private keys, internal hostnames, or public IPs unless you intentionally want them public. + + If your question turns into a reproducible product defect, DevRel or a maintainer may ask you to open or move the conversation to Issue Triage. + + - type: checkboxes + id: preflight + attributes: + label: Before posting + options: + - label: I searched existing discussions and issues for similar questions. + required: true + - label: I reviewed the relevant NetBird documentation or troubleshooting guide. + required: true + - label: I removed or anonymized sensitive data from logs, screenshots, and configuration. 
+ required: true + + - type: dropdown + id: topic + attributes: + label: Topic + multiple: true + options: + - Getting started + - Self-hosting + - Client / Agent + - CLI + - Desktop UI + - Mobile app + - Dashboard / Admin UI + - DNS + - Routes / Exit nodes + - NetBird SSH + - Relay + - Access control policies + - Posture checks + - Identity provider / SSO + - API + - Kubernetes / Operator + - Terraform / Automation + - Documentation + - Other / not sure + validations: + required: true + + - type: dropdown + id: deployment + attributes: + label: Deployment type + options: + - NetBird Cloud + - Self-hosted - quickstart script + - Self-hosted - advanced/custom deployment + - Local development build + - Not sure + validations: + required: true + + - type: dropdown + id: platform + attributes: + label: Operating system or environment + multiple: true + options: + - Linux + - macOS + - Windows + - Android + - iOS + - FreeBSD + - OpenWRT + - Docker + - Kubernetes + - Synology + - Browser + - Other / not sure + validations: + required: true + + - type: input + id: version + attributes: + label: NetBird version + description: Run `netbird version` where applicable. For self-hosted deployments, include component versions if relevant. + placeholder: "Example: client 0.30.2, management 0.30.2" + + - type: textarea + id: question + attributes: + label: Question + description: What are you trying to understand or accomplish? + placeholder: Describe your question clearly. + validations: + required: true + + - type: textarea + id: goal + attributes: + label: Desired outcome + description: What would a successful answer help you do? + placeholder: | + I want to configure ... + I expected ... + I need help deciding ... + + - type: textarea + id: attempted + attributes: + label: What have you tried? + description: Include commands, documentation links, configuration attempts, or troubleshooting steps already tried. + placeholder: | + - Read ... + - Ran ... + - Changed ... 
+ - Observed ... + + - type: textarea + id: environment + attributes: + label: Relevant environment details + description: Include redacted topology, IdP/provider, reverse proxy, firewall, DNS, route, policy, or self-hosted setup details that may affect the answer. + placeholder: | + - Deployment: + - Components involved: + - Network/topology: + - Related config: + + - type: textarea + id: logs + attributes: + label: Logs or output + description: Optional. Include anonymized logs, command output, screenshots, or `netbird status -dA` if relevant. + render: shell + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add links, diagrams, screenshots, or other details that may help the community answer. diff --git a/.github/ISSUE_TEMPLATE/bug-issue-report.md b/.github/ISSUE_TEMPLATE/bug-issue-report.md deleted file mode 100644 index df670db06..000000000 --- a/.github/ISSUE_TEMPLATE/bug-issue-report.md +++ /dev/null @@ -1,71 +0,0 @@ ---- -name: Bug/Issue report -about: Create a report to help us improve -title: '' -labels: ['triage-needed'] -assignees: '' - ---- - -**Describe the problem** - -A clear and concise description of what the problem is. - -**To Reproduce** - -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** - -A clear and concise description of what you expected to happen. - -**Are you using NetBird Cloud?** - -Please specify whether you use NetBird Cloud or self-host NetBird's control plane. - -**NetBird version** - -`netbird version` - -**Is any other VPN software installed?** - -If yes, which one? 
- -**Debug output** - -To help us resolve the problem, please attach the following anonymized status output - - netbird status -dA - -Create and upload a debug bundle, and share the returned file key: - - netbird debug for 1m -AS -U - -*Uploaded files are automatically deleted after 30 days.* - - -Alternatively, create the file only and attach it here manually: - - netbird debug for 1m -AS - - -**Screenshots** - -If applicable, add screenshots to help explain your problem. - -**Additional context** - -Add any other context about the problem here. - -**Have you tried these troubleshooting steps?** -- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable) -- [ ] Checked for newer NetBird versions -- [ ] Searched for similar issues on GitHub (including closed ones) -- [ ] Restarted the NetBird client -- [ ] Disabled other VPN software -- [ ] Checked firewall settings - diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index e9ffaf8a3..ee3e84df6 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,14 +1,26 @@ -blank_issues_enabled: true +blank_issues_enabled: false contact_links: - - name: Community Support + - name: Start an Issue Triage discussion + url: https://github.com/netbirdio/netbird/discussions/new?category=issue-triage + about: Report a bug, regression, or unexpected behavior so DevRel can validate it before it becomes an issue. + - name: Propose an idea or feature request + url: https://github.com/netbirdio/netbird/discussions/new?category=ideas-feature-requests + about: Share feature requests, enhancements, and integration ideas for community feedback and prioritization. + - name: Ask a Q&A / Support question + url: https://github.com/netbirdio/netbird/discussions/new?category=q-a-support + about: Get help with setup, configuration, self-hosting, troubleshooting, and general usage. 
+ - name: Security vulnerability disclosure + url: https://github.com/netbirdio/netbird/security/policy + about: Please do not report security vulnerabilities in public issues or discussions. + - name: Community Support Forum url: https://forum.netbird.io/ - about: Community support forum + about: Community support forum. - name: Cloud Support url: https://docs.netbird.io/help/report-bug-issues - about: Contact us for support - - name: Client/Connection Troubleshooting + about: Contact NetBird for Cloud support. + - name: Client / Connection Troubleshooting url: https://docs.netbird.io/help/troubleshooting-client - about: See our client troubleshooting guide for help addressing common issues + about: See the client troubleshooting guide for common connectivity issues. - name: Self-host Troubleshooting url: https://docs.netbird.io/selfhosted/troubleshooting - about: See our self-host troubleshooting guide for help addressing common issues + about: See the self-host troubleshooting guide for common deployment issues. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index 4a3e5782c..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: '' -labels: ['feature-request'] -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Additional context** -Add any other context or screenshots about the feature request here. 
diff --git a/.github/ISSUE_TEMPLATE/validated_issue.yml b/.github/ISSUE_TEMPLATE/validated_issue.yml new file mode 100644 index 000000000..2a21b73b2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/validated_issue.yml @@ -0,0 +1,128 @@ +name: Validated issue +description: Maintainer/DevRel only. Create an issue after a discussion has been validated or for internally validated work. +title: "[Validated]: " +body: + - type: markdown + attributes: + value: | + ## Discussion-first issue policy + + Issues are maintainer-curated work items. Community reports and feature requests should start in [Discussions](https://github.com/netbirdio/netbird/discussions) so DevRel can validate, reproduce, and route them before engineering time is committed. + + Use this form when: + - A discussion has been validated and should become actionable work. + - A maintainer is opening internally validated work that can bypass the discussion-first flow. + + Issues opened without a relevant validated discussion or maintainer context may be closed and redirected to Discussions. + + - type: checkboxes + id: validation-checks + attributes: + label: Validation checklist + options: + - label: This issue is linked to a validated discussion, or it is being opened directly by a maintainer. + required: true + - label: The report has enough context for engineering to act on it without re-triaging from scratch. + required: true + - label: Sensitive data, secrets, private keys, internal hostnames, and public IPs have been removed or intentionally disclosed. + required: true + + - type: dropdown + id: issue-type + attributes: + label: Issue type + options: + - Bug / Regression + - Feature / Enhancement + - Documentation + - Maintenance / Refactor + - Cross-repository coordination + - Other + validations: + required: true + + - type: input + id: source-discussion + attributes: + label: Source discussion + description: Link the GitHub Discussion that was validated. 
Maintainers bypassing the flow can write "Maintainer-created" and explain why below. + placeholder: https://github.com/netbirdio/netbird/discussions/1234 + validations: + required: true + + - type: input + id: validation-owner + attributes: + label: Validation owner + description: GitHub handle of the DevRel team member or maintainer who validated this work. + placeholder: "@username" + validations: + required: true + + - type: dropdown + id: target-repository + attributes: + label: Target repository + description: Where should the implementation work happen? + options: + - netbirdio/netbird + - netbirdio/dashboard + - netbirdio/kubernetes-operator + - netbirdio/docs + - Multiple repositories + - Unknown / needs routing + validations: + required: true + + - type: textarea + id: summary + attributes: + label: Summary + description: Concise description of the validated work. + placeholder: What needs to be fixed, changed, documented, or built? + validations: + required: true + + - type: textarea + id: evidence + attributes: + label: Validation evidence + description: For bugs, include reproduction status, affected versions, logs, and environment. For features, include community traction, affected users, and alignment notes. + placeholder: | + - Reproduced by: + - Affected versions / platforms: + - Community signal: + - Related logs or screenshots: + validations: + required: true + + - type: textarea + id: scope + attributes: + label: Proposed scope + description: Describe what is in scope and, if helpful, what is explicitly out of scope. + placeholder: | + In scope: + - ... + + Out of scope: + - ... + validations: + required: true + + - type: textarea + id: acceptance-criteria + attributes: + label: Acceptance criteria + description: What must be true for this issue to be closed? + placeholder: | + - [ ] ... + - [ ] ... 
+ validations: + required: true + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Links to related PRs, docs, issues in other repositories, roadmap items, or implementation notes. diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5ada1033d..c1ae01a98 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.1.2" + SIGN_PIPE_VER: "v0.1.4" GORELEASER_VER: "v2.14.3" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -114,7 +114,13 @@ jobs: retention-days: 30 release: - runs-on: ubuntu-latest-m + runs-on: ubuntu-24.04-8-core + outputs: + release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }} + linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }} + windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }} + macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }} + ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }} env: flags: "" steps: @@ -213,10 +219,13 @@ jobs: if: always() run: rm -f /tmp/gpg-rpm-signing-key.asc - name: Tag and push images (amd64 only) + id: tag_and_push_images if: | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || (github.event_name == 'push' && github.ref == 'refs/heads/main') run: | + set -euo pipefail + resolve_tags() { if [[ "${{ github.event_name }}" == "pull_request" ]]; then echo "pr-${{ github.event.pull_request.number }}" @@ -225,6 +234,17 @@ jobs: fi } + ghcr_package_url() { + local image="$1" package encoded_package + package="${image#ghcr.io/}" + package="${package#*/}" + package="${package%%:*}" + encoded_package="${package//\//%2F}" + echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}" + } + + image_refs=() + tag_and_push() { local src="$1" img_name 
tag dst img_name="${src%%:*}" @@ -233,35 +253,56 @@ jobs: echo "Tagging ${src} -> ${dst}" docker tag "$src" "$dst" docker push "$dst" + image_refs+=("$dst") done } - export -f tag_and_push resolve_tags + cat > /tmp/goreleaser-artifacts.json <<'JSON' + ${{ steps.goreleaser.outputs.artifacts }} + JSON - echo '${{ steps.goreleaser.outputs.artifacts }}' | \ - jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \ - grep '^ghcr.io/' | while read -r SRC; do - tag_and_push "$SRC" - done + mapfile -t src_images < <( + jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json + ) + + for src in "${src_images[@]}"; do + tag_and_push "$src" + done + + { + echo "images_markdown<> "$GITHUB_OUTPUT" - name: upload non tags for debug purposes + id: upload_release uses: actions/upload-artifact@v4 with: name: release path: dist/ retention-days: 7 - name: upload linux packages + id: upload_linux_packages uses: actions/upload-artifact@v4 with: name: linux-packages path: dist/netbird_linux** retention-days: 7 - name: upload windows packages + id: upload_windows_packages uses: actions/upload-artifact@v4 with: name: windows-packages path: dist/netbird_windows** retention-days: 7 - name: upload macos packages + id: upload_macos_packages uses: actions/upload-artifact@v4 with: name: macos-packages @@ -270,6 +311,8 @@ jobs: release_ui: runs-on: ubuntu-latest + outputs: + release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }} steps: - name: Parse semver string id: semver_parser @@ -360,6 +403,7 @@ jobs: if: always() run: rm -f /tmp/gpg-rpm-signing-key.asc - name: upload non tags for debug purposes + id: upload_release_ui uses: actions/upload-artifact@v4 with: name: release-ui @@ -368,6 +412,8 @@ jobs: release_ui_darwin: runs-on: macos-latest + outputs: + release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }} steps: - 
if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV @@ -402,15 +448,258 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: upload non tags for debug purposes + id: upload_release_ui_darwin uses: actions/upload-artifact@v4 with: name: release-ui-darwin path: dist/ retention-days: 3 - trigger_signer: + test_windows_installer: + name: "Windows Installer / Build Test" + runs-on: windows-2022 + needs: [release, release_ui] + strategy: + fail-fast: false + matrix: + include: + - arch: amd64 + wintun_arch: amd64 + - arch: arm64 + wintun_arch: arm64 + defaults: + run: + shell: powershell + env: + PackageWorkdir: netbird_windows_${{ matrix.arch }} + downloadPath: '${{ github.workspace }}\temp' + steps: + - name: Parse semver string + id: semver_parser + uses: booxmedialtd/ws-action-parse-semver@v1 + with: + input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }} + version_extractor_regex: '\/v(.*)$' + + - name: Checkout + uses: actions/checkout@v4 + + - name: Add 7-Zip to PATH + run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + + - name: Download release artifacts + uses: actions/download-artifact@v4 + with: + name: release + path: release + + - name: Download UI release artifacts + uses: actions/download-artifact@v4 + with: + name: release-ui + path: release-ui + + - name: Stage binaries into dist + run: | + $workdir = "dist\${{ env.PackageWorkdir }}" + New-Item -ItemType Directory -Force -Path $workdir | Out-Null + $client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1 + $ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1 + if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 } + if (-not $ui) { Write-Host "::error::ui tarball not 
found for ${{ matrix.arch }}"; exit 1 } + Write-Host "Client: $($client.FullName)" + Write-Host "UI: $($ui.FullName)" + tar -zvxf $client.FullName -C $workdir + tar -zvxf $ui.FullName -C $workdir + Get-ChildItem $workdir + + - name: Download wintun + uses: carlosperate/download-file-action@v2 + id: download-wintun + with: + file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip + file-name: wintun.zip + location: ${{ env.downloadPath }} + sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51' + + - name: Decompress wintun files + run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }} + + - name: Move wintun.dll into dist + run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\ + + - name: Download Mesa3D (amd64 only) + uses: carlosperate/download-file-action@v2 + id: download-mesa3d + if: matrix.arch == 'amd64' + with: + file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z + file-name: mesa3d.7z + location: ${{ env.downloadPath }} + sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9' + + - name: Extract Mesa3D driver (amd64 only) + if: matrix.arch == 'amd64' + run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z" + + - name: Move opengl32.dll into dist (amd64 only) + if: matrix.arch == 'amd64' + run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\ + + - name: Download EnVar plugin for NSIS + uses: carlosperate/download-file-action@v2 + with: + file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip + file-name: envar_plugin.zip + location: ${{ github.workspace }} + + - name: Extract EnVar plugin + run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip" + + - name: Download ShellExecAsUser plugin for NSIS (amd64 only) + uses: 
carlosperate/download-file-action@v2 + if: matrix.arch == 'amd64' + with: + file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z + file-name: ShellExecAsUser_amd64-Unicode.7z + location: ${{ github.workspace }} + + - name: Extract ShellExecAsUser plugin (amd64 only) + if: matrix.arch == 'amd64' + run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z" + + - name: Build NSIS installer + uses: joncloud/makensis-action@v3.3 + with: + additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins + script-file: client/installer.nsis + arguments: "/V4 /DARCH=${{ matrix.arch }}" + env: + APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }} + + - name: Rename NSIS installer + run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe + + - name: Install WiX + run: | + dotnet tool install --global wix --version 6.0.2 + wix extension add WixToolset.Util.wixext/6.0.2 + + - name: Build MSI installer + env: + NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}" + run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }} + + - name: Upload installer artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: windows-installer-test-${{ matrix.arch }} + path: | + netbird_installer_test_windows_${{ matrix.arch }}.exe + netbird_installer_test_windows_${{ matrix.arch }}.msi + retention-days: 3 + + comment_release_artifacts: + name: Comment release artifacts runs-on: ubuntu-latest needs: [release, release_ui, release_ui_darwin] + if: ${{ always() && github.event_name == 'pull_request' && 
github.event.pull_request.head.repo.full_name == github.repository }} + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Create or update PR comment + uses: actions/github-script@v7 + env: + RELEASE_RESULT: ${{ needs.release.result }} + RELEASE_UI_RESULT: ${{ needs.release_ui.result }} + RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }} + RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }} + LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }} + WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }} + MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }} + RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }} + RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }} + GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }} + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const marker = ''; + const { owner, repo } = context.repo; + const issue_number = context.payload.pull_request.number; + const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`; + const shortSha = context.payload.pull_request.head.sha.slice(0, 7); + + const artifactCell = (url, result) => { + if (url) return `[Download](${url})`; + return result && result !== 'success' ? 
`_Not available (${result})_` : '_Not available_'; + }; + + const artifacts = [ + ['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT], + ['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT], + ]; + + const artifactRows = artifacts + .map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`) + .join('\n'); + + const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._'; + + const body = [ + marker, + '## Release artifacts', + '', + `Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`, + '', + '| Artifact | Link |', + '| --- | --- |', + artifactRows, + '', + '### GHCR images (amd64)', + ghcrImages, + '', + '_This comment is updated by the Release workflow. 
Artifact links expire according to the workflow retention policy._', + ].join('\n'); + + const comments = await github.paginate(github.rest.issues.listComments, { + owner, + repo, + issue_number, + per_page: 100, + }); + + const previous = comments.find(comment => + comment.user?.type === 'Bot' && comment.body?.includes(marker) + ); + + if (previous) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: previous.id, + body, + }); + core.info(`Updated release artifacts comment ${previous.id}`); + } else { + const { data } = await github.rest.issues.createComment({ + owner, + repo, + issue_number, + body, + }); + core.info(`Created release artifacts comment ${data.id}`); + } + + trigger_signer: + runs-on: ubuntu-latest + needs: [release, release_ui, release_ui_darwin, test_windows_installer] if: startsWith(github.ref, 'refs/tags/') steps: - name: Trigger binaries sign pipelines diff --git a/.github/workflows/sync-tag.yml b/.github/workflows/sync-tag.yml index 1cc553b12..a75d9a9d5 100644 --- a/.github/workflows/sync-tag.yml +++ b/.github/workflows/sync-tag.yml @@ -9,6 +9,8 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} cancel-in-progress: true +# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short +# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref. 
jobs: trigger_sync_tag: runs-on: ubuntu-latest @@ -20,4 +22,30 @@ jobs: ref: main repo: ${{ secrets.UPSTREAM_REPO }} token: ${{ secrets.NC_GITHUB_TOKEN }} + inputs: '{ "tag": "${{ github.ref_name }}" }' + + trigger_android_bump: + runs-on: ubuntu-latest + if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-') + steps: + - name: Trigger android-client submodule bump + uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1 + with: + workflow: bump-netbird.yml + ref: main + repo: netbirdio/android-client + token: ${{ secrets.NC_GITHUB_TOKEN }} + inputs: '{ "tag": "${{ github.ref_name }}" }' + + trigger_ios_bump: + runs-on: ubuntu-latest + if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-') + steps: + - name: Trigger ios-client submodule bump + uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1 + with: + workflow: bump-netbird.yml + ref: main + repo: netbirdio/ios-client + token: ${{ secrets.NC_GITHUB_TOKEN }} inputs: '{ "tag": "${{ github.ref_name }}" }' \ No newline at end of file diff --git a/.gitignore b/.gitignore index a0f128933..783fe77f3 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ infrastructure_files/setup-*.env vendor/ /netbird client/netbird-electron/ +management/server/types/testdata/ diff --git a/.golangci.yaml b/.golangci.yaml index d81ad1377..900af4ac0 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -58,6 +58,11 @@ linters: govet: enable: - nilness + disable: + # The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline + # directives but cannot perform the rewrite due to generic type + # parameter inference limitations in the Go inliner. 
+ - inline enable-all: false revive: rules: diff --git a/client/Dockerfile b/client/Dockerfile index 64d5ba04f..53e4555ef 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -17,6 +17,7 @@ ENV \ NETBIRD_BIN="/usr/local/bin/netbird" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ + NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 69d00aaf2..706bf40de 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -23,6 +23,7 @@ ENV \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_DISABLE_DNS="true" \ + NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/android/client.go b/client/android/client.go index 37e17a363..99ccdf393 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -301,10 +301,11 @@ func (c *Client) PeersList() *PeerInfoArray { peerInfos := make([]PeerInfo, len(fullStatus.Peers)) for n, p := range fullStatus.Peers { pi := PeerInfo{ - p.IP, - p.FQDN, - int(p.ConnStatus), - PeerRoutes{routes: maps.Keys(p.GetRoutes())}, + IP: p.IP, + IPv6: p.IPv6, + FQDN: p.FQDN, + ConnStatus: int(p.ConnStatus), + Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi } @@ -336,43 +337,84 @@ func (c *Client) Networks() *NetworkArray { return nil } + routesMap := routeManager.GetClientRoutesWithNetID() + v6Merged := route.V6ExitMergeSet(routesMap) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + networkArray := &NetworkArray{ items: make([]Network, 0), } - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - for id, routes := range routeManager.GetClientRoutesWithNetID() { + for id, routes := range routesMap { if len(routes) == 0 { continue } - - r := routes[0] 
- domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) - netStr := r.Network.String() - - if r.IsDynamic() { - netStr = r.Domains.SafeString() - } - - routePeer, err := c.recorder.GetPeer(routes[0].Peer) - if err != nil { - log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) + if _, skip := v6Merged[id]; skip { continue } - network := Network{ - Name: string(id), - Network: netStr, - Peer: routePeer.FQDN, - Status: routePeer.ConnStatus.String(), - IsSelected: routeSelector.IsSelected(id), - Domains: domains, + + network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged) + if network == nil { + continue } - networkArray.Add(network) + networkArray.Add(*network) } return networkArray } +func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network { + r := routes[0] + netStr := r.Network.String() + if r.IsDynamic() { + netStr = r.Domains.SafeString() + } + + routePeer, err := c.findBestRoutePeer(routes) + if err != nil { + log.Errorf("could not get peer info for route %s: %v", id, err) + return nil + } + + network := &Network{ + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: selected, + Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains), + } + + if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) { + network.Network = "0.0.0.0/0, ::/0" + } + + return network +} + +// findBestRoutePeer returns the peer actively routing traffic for the given +// HA route group. Falls back to the first connected peer, then the first peer. 
+func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) { + netStr := routes[0].Network.String() + + fullStatus := c.recorder.GetFullStatus() + for _, p := range fullStatus.Peers { + if _, ok := p.GetRoutes()[netStr]; ok { + return p, nil + } + } + + for _, r := range routes { + p, err := c.recorder.GetPeer(r.Peer) + if err != nil { + continue + } + if p.ConnStatus == peer.StatusConnected { + return p, nil + } + } + return c.recorder.GetPeer(routes[0].Peer) +} + // OnUpdatedHostDNS update the DNS servers addresses for root zones func (c *Client) OnUpdatedHostDNS(list *DNSList) error { dnsServer, err := dns.GetServerDns() diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index 4ec22f3ab..c2595e574 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -14,6 +14,7 @@ const ( // PeerInfo describe information about the peers. It designed for the UI usage type PeerInfo struct { IP string + IPv6 string FQDN string ConnStatus int Routes PeerRoutes diff --git a/client/android/preferences.go b/client/android/preferences.go index c3c8eb3fb..066477293 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -307,6 +307,24 @@ func (p *Preferences) SetBlockInbound(block bool) { p.configInput.BlockInbound = &block } +// GetDisableIPv6 reads disable IPv6 setting from config file +func (p *Preferences) GetDisableIPv6() (bool, error) { + if p.configInput.DisableIPv6 != nil { + return *p.configInput.DisableIPv6, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableIPv6, err +} + +// SetDisableIPv6 stores the given value and waits for commit +func (p *Preferences) SetDisableIPv6(disable bool) { + p.configInput.DisableIPv6 = &disable +} + // Commit writes out the changes to the config file func (p *Preferences) Commit() error { _, err := profilemanager.UpdateOrCreateConfig(p.configInput) diff 
--git a/client/android/route_command.go b/client/android/route_command.go index b47d5ca6c..5e7357335 100644 --- a/client/android/route_command.go +++ b/client/android/route_command.go @@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager, netID := route.NetID(id) routes := []route.NetID{netID} - log.Debugf("%s with id: %s", operationName, id) + routesMap := manager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) - if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("%s with ids: %v", operationName, routes) + + if err := routeOperation(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when %s: %s", operationName, err) return fmt.Errorf("error %s: %w", operationName, err) } diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 89e653300..c140cef89 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -9,6 +9,7 @@ import ( "net/url" "regexp" "slices" + "strconv" "strings" ) @@ -26,8 +27,9 @@ type Anonymizer struct { } func DefaultAddresses() (netip.Addr, netip.Addr) { - // 198.51.100.0, 100:: - return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) + // 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48) + // The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android. 
+ return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::") } func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { @@ -48,7 +50,7 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr { ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() || - ip.IsPrivate() || + (ip.Is4() && ip.IsPrivate()) || ip.IsUnspecified() || ip.IsMulticast() || isWellKnown(ip) || @@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { } func (a *Anonymizer) AnonymizeIPString(ip string) string { + // Handle CIDR notation (e.g. "2001:db8::/32") + if prefix, err := netip.ParsePrefix(ip); err == nil { + return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits()) + } + addr, err := netip.ParseAddr(ip) if err != nil { return ip @@ -150,7 +157,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string { if u.Opaque != "" { host, port, err := net.SplitHostPort(u.Opaque) if err == nil { - anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) + anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port) } else { anonymizedHost = a.AnonymizeDomain(u.Opaque) } @@ -158,7 +165,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string { } else if u.Host != "" { host, port, err := net.SplitHostPort(u.Host) if err == nil { - anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) + anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port) } else { anonymizedHost = a.AnonymizeDomain(u.Host) } diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index ff2e48869..852315fa1 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -13,7 +13,7 @@ import ( func TestAnonymizeIP(t *testing.T) { startIPv4 := netip.MustParseAddr("198.51.100.0") - startIPv6 := netip.MustParseAddr("100::") + startIPv6 := netip.MustParseAddr("2001:db8:ffff::") anonymizer := 
anonymize.NewAnonymizer(startIPv4, startIPv6) tests := []struct { @@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) { {"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, {"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, {"Private IPv4", "192.168.1.1", "192.168.1.1"}, - {"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"}, - {"Second Public IPv6", "a::b", "100::1"}, - {"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"}, + {"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, + {"Second Public IPv6", "a::b", "2001:db8:ffff::1"}, + {"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, {"Private IPv6", "fe80::1", "fe80::1"}, {"In Range IPv4", "198.51.100.2", "198.51.100.2"}, } @@ -274,17 +274,27 @@ func TestAnonymizeString_IPAddresses(t *testing.T) { { name: "IPv6 Address", input: "Access attempted from 2001:db8::ff00:42", - expect: "Access attempted from 100::", + expect: "Access attempted from 2001:db8:ffff::", }, { name: "IPv6 Address with Port", input: "Access attempted from [2001:db8::ff00:42]:8080", - expect: "Access attempted from [100::]:8080", + expect: "Access attempted from [2001:db8:ffff::]:8080", }, { name: "Both IPv4 and IPv6", input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", - expect: "IPv4: 198.51.100.1 and IPv6: 100::1", + expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1", + }, + { + name: "STUN URI with IPv6", + input: "Connecting to stun:[2001:db8::ff00:42]:3478", + expect: "Connecting to stun:[2001:db8:ffff::]:3478", + }, + { + name: "HTTPS URI with IPv6", + input: "Visit https://[2001:db8::ff00:42]:443/path", + expect: "Visit https://[2001:db8:ffff::]:443/path", }, } diff --git a/client/cmd/capture.go b/client/cmd/capture.go new file mode 100644 index 000000000..95caaa5cd --- /dev/null +++ b/client/cmd/capture.go @@ -0,0 +1,196 @@ +package cmd + +import ( + "context" + "fmt" + "io" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/hashicorp/go-multierror" + 
"github.com/spf13/cobra" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +var captureCmd = &cobra.Command{ + Use: "capture", + Short: "Capture packets on the WireGuard interface", + Long: `Captures decrypted packets flowing through the WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Requires --enable-capture to be set at service install or reconfigure time. + +Examples: + netbird debug capture + netbird debug capture host 100.64.0.1 and port 443 + netbird debug capture tcp + netbird debug capture icmp + netbird debug capture src host 10.0.0.1 and dst port 80 + netbird debug capture -o capture.pcap + netbird debug capture --pcap | tshark -r - + netbird debug capture --pcap | tcpdump -r - -n`, + Args: cobra.ArbitraryArgs, + RunE: runCapture, +} + +func init() { + debugCmd.AddCommand(captureCmd) + + captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length") + captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)") + captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)") + captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)") + captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") +} + +func runCapture(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + cmd.PrintErrf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + + req, err := buildCaptureRequest(cmd, args) + if err != nil { + return err + } + + ctx, cancel 
:= signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + stream, err := client.StartCapture(ctx, req) + if err != nil { + return handleCaptureError(err) + } + + // First Recv is the empty acceptance message from the server. If the + // device is unavailable (kernel WG, not connected, capture disabled), + // the server returns an error instead. + if _, err := stream.Recv(); err != nil { + return handleCaptureError(err) + } + + out, cleanup, err := captureOutput(cmd) + if err != nil { + return err + } + + if req.TextOutput { + cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n") + } else { + cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n") + } + + streamErr := streamCapture(ctx, cmd, stream, out) + cleanupErr := cleanup() + if streamErr != nil { + return streamErr + } + return cleanupErr +} + +func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) { + req := &proto.StartCaptureRequest{} + + if len(args) > 0 { + expr := strings.Join(args, " ") + if _, err := capture.ParseFilter(expr); err != nil { + return nil, fmt.Errorf("invalid filter: %w", err) + } + req.FilterExpr = expr + } + + if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 { + req.SnapLen = snap + } + if d, _ := cmd.Flags().GetDuration("duration"); d != 0 { + if d < 0 { + return nil, fmt.Errorf("duration must not be negative") + } + req.Duration = durationpb.New(d) + } + req.Verbose, _ = cmd.Flags().GetBool("verbose") + req.Ascii, _ = cmd.Flags().GetBool("ascii") + + outPath, _ := cmd.Flags().GetString("output") + forcePcap, _ := cmd.Flags().GetBool("pcap") + req.TextOutput = !forcePcap && outPath == "" + + return req, nil +} + +func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error { + for { + pkt, err := stream.Recv() + if err != nil { + if ctx.Err() != nil { + cmd.PrintErrf("\nCapture stopped.\n") + return nil 
//nolint:nilerr // user interrupted + } + if err == io.EOF { + cmd.PrintErrf("\nCapture finished.\n") + return nil + } + return handleCaptureError(err) + } + if _, err := out.Write(pkt.GetData()); err != nil { + return fmt.Errorf("write output: %w", err) + } + } +} + +// captureOutput returns the writer for capture data and a cleanup function +// that finalizes the file. Errors from the cleanup must be propagated. +func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) { + outPath, _ := cmd.Flags().GetString("output") + if outPath == "" { + return os.Stdout, func() error { return nil }, nil + } + + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() error { + var merr *multierror.Error + if err := f.Close(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err)) + } + fi, statErr := os.Stat(tmpPath) + if statErr != nil || fi.Size() == 0 { + if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) { + merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr)) + } + return nberrors.FormatErrorOrNil(merr) + } + if err := os.Rename(tmpPath, outPath); err != nil { + merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err)) + return nberrors.FormatErrorOrNil(merr) + } + cmd.PrintErrf("Wrote %s\n", outPath) + return nberrors.FormatErrorOrNil(merr) + }, nil +} + +func handleCaptureError(err error) error { + if s, ok := status.FromError(err); ok { + return fmt.Errorf("%s", s.Message()) + } + return err +} diff --git a/client/cmd/debug.go b/client/cmd/debug.go index e3d3afe5f..2a8cdc887 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" 
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/debug" @@ -239,11 +240,50 @@ func runForDuration(cmd *cobra.Command, args []string) error { }() } + captureStarted := false + if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture { + captureTimeout := duration + 30*time.Second + const maxBundleCapture = 10 * time.Minute + if captureTimeout > maxBundleCapture { + captureTimeout = maxBundleCapture + } + _, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(captureTimeout), + }) + if err != nil { + cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message()) + } else { + captureStarted = true + cmd.Println("Packet capture started.") + // Safety: always stop on exit, even if the normal stop below runs too. + defer func() { + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } + } + }() + } + } + if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { return waitErr } cmd.Println("\nDuration completed") + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } else { + captureStarted = false + cmd.Println("Packet capture stopped.") + } + } + if cpuProfilingStarted { if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) @@ -416,4 +456,5 @@ func init() { forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") 
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") + forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle") } diff --git a/client/cmd/login.go b/client/cmd/login.go index 4521a67c9..bd37e30f1 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/term" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -23,6 +24,7 @@ import ( func init() { loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) + loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc) loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location") } @@ -256,7 +258,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, } func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error { - openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser) + openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR) resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) if err != nil { @@ -324,7 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) } - openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) + openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, 
noBrowser, showQR) tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) if err != nil { @@ -334,7 +336,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro return &tokenInfo, nil } -func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) { +func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) { var codeMsg string if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) @@ -348,6 +350,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro verificationURIComplete + " " + codeMsg) } + if showQR { + if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) { + printQRCode(f, verificationURIComplete) + } + } + cmd.Println("") if !noBrowser { diff --git a/client/cmd/qr.go b/client/cmd/qr.go new file mode 100644 index 000000000..8b2c489ff --- /dev/null +++ b/client/cmd/qr.go @@ -0,0 +1,25 @@ +package cmd + +import ( + "io" + + "github.com/mdp/qrterminal/v3" +) + +// printQRCode prints a QR code for the given URL to the writer. +// Called only when the user explicitly requests QR output via --qr. 
+func printQRCode(w io.Writer, url string) { + if url == "" { + return + } + qrterminal.GenerateWithConfig(url, qrterminal.Config{ + Level: qrterminal.M, + Writer: w, + HalfBlocks: true, + BlackChar: qrterminal.BLACK_BLACK, + WhiteChar: qrterminal.WHITE_WHITE, + BlackWhiteChar: qrterminal.BLACK_WHITE, + WhiteBlackChar: qrterminal.WHITE_BLACK, + QuietZone: qrterminal.QUIET_ZONE, + }) +} diff --git a/client/cmd/qr_test.go b/client/cmd/qr_test.go new file mode 100644 index 000000000..d12705b9e --- /dev/null +++ b/client/cmd/qr_test.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "bytes" + "testing" +) + +func TestPrintQRCode_EmptyURL(t *testing.T) { + var buf bytes.Buffer + + printQRCode(&buf, "") + + if buf.Len() != 0 { + t.Error("expected no output for empty URL") + } +} + +func TestPrintQRCode_WritesOutput(t *testing.T) { + var buf bytes.Buffer + + printQRCode(&buf, "https://example.com/auth") + + if buf.Len() == 0 { + t.Error("expected QR code output for non-empty URL") + } +} diff --git a/client/cmd/root.go b/client/cmd/root.go index c872fe9f6..29d4328a1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -75,6 +75,7 @@ var ( mtu uint16 profilesDisabled bool updateSettingsDisabled bool + captureEnabled bool networksDisabled bool rootCmd = &cobra.Command{ diff --git a/client/cmd/service.go b/client/cmd/service.go index f1123ce8c..56d8a8726 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -44,6 +44,7 @@ func init() { serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. 
If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") + serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture") serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0943b6184..88121c067 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 5ada6f633..2d45fa063 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -59,6 +59,10 @@ func buildServiceArguments() []string { args = append(args, "--disable-update-settings") } + if captureEnabled { + args = append(args, "--enable-capture") + } + if networksDisabled { args = append(args, "--disable-networks") } diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go index 5a86aebc6..192e0ac60 100644 --- a/client/cmd/service_params.go +++ b/client/cmd/service_params.go 
@@ -28,6 +28,7 @@ type serviceParams struct { LogFiles []string `json:"log_files,omitempty"` DisableProfiles bool `json:"disable_profiles,omitempty"` DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` + EnableCapture bool `json:"enable_capture,omitempty"` DisableNetworks bool `json:"disable_networks,omitempty"` ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` } @@ -79,6 +80,7 @@ func currentServiceParams() *serviceParams { LogFiles: logFiles, DisableProfiles: profilesDisabled, DisableUpdateSettings: updateSettingsDisabled, + EnableCapture: captureEnabled, DisableNetworks: networksDisabled, } @@ -144,6 +146,10 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) { updateSettingsDisabled = params.DisableUpdateSettings } + if !serviceCmd.PersistentFlags().Changed("enable-capture") { + captureEnabled = params.EnableCapture + } + if !serviceCmd.PersistentFlags().Changed("disable-networks") { networksDisabled = params.DisableNetworks } diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go index 7e04e5abe..f338c12f4 100644 --- a/client/cmd/service_params_test.go +++ b/client/cmd/service_params_test.go @@ -535,6 +535,7 @@ func fieldToGlobalVar(field string) string { "LogFiles": "logFiles", "DisableProfiles": "profilesDisabled", "DisableUpdateSettings": "updateSettingsDisabled", + "EnableCapture": "captureEnabled", "DisableNetworks": "networksDisabled", "ServiceEnvVars": "serviceEnvVars", } diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 0acf0b133..d6e052e08 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error { } func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { - target := fmt.Sprintf("%s:%d", addr, port) + target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port)) c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{ KnownHostsFile: knownHostsFile, 
IdentityFile: identityFile, @@ -787,10 +787,10 @@ func isUnixSocket(path string) bool { return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") } -// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack). func normalizeLocalHost(host string) string { if host == "*" { - return "0.0.0.0" + return "" } return host } diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go index 43291fa87..16ffadb90 100644 --- a/client/cmd/ssh_test.go +++ b/client/cmd/ssh_test.go @@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) { { name: "wildcard bind all interfaces", spec: "*:8080:localhost:80", - expectedLocal: "0.0.0.0:8080", + expectedLocal: ":8080", expectedRemote: "localhost:80", expectError: false, - description: "Wildcard * should bind to all interfaces (0.0.0.0)", + description: "Wildcard * should bind to all interfaces (dual-stack)", }, { name: "wildcard for port only", diff --git a/client/cmd/status.go b/client/cmd/status.go index c35a06eb3..dae30e854 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -20,6 +20,7 @@ import ( var ( detailFlag bool ipv4Flag bool + ipv6Flag bool jsonFlag bool yamlFlag bool ipsFilter []string @@ -45,8 +46,9 @@ func init() { statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") - statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") - statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") + 
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer") + statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6") + statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") @@ -101,6 +103,14 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } + if ipv6Flag { + ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6() + if ipv6 != "" { + cmd.Print(parseInterfaceIP(ipv6)) + } + return nil + } + pm := profilemanager.NewProfileManager() var profName string if activeProf, err := pm.GetActiveProfile(); err == nil { diff --git a/client/cmd/system.go b/client/cmd/system.go index f63432401..b386fe4ae 100644 --- a/client/cmd/system.go +++ b/client/cmd/system.go @@ -8,6 +8,7 @@ const ( disableFirewallFlag = "disable-firewall" blockLANAccessFlag = "block-lan-access" blockInboundFlag = "block-inbound" + disableIPv6Flag = "disable-ipv6" ) var ( @@ -17,6 +18,7 @@ var ( disableFirewall bool blockLANAccess bool blockInbound bool + disableIPv6 bool ) func init() { @@ -39,4 +41,7 @@ func init() { upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, "Block inbound connections. 
If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "This overrides any policies received from the management service.") + + upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false, + "Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.") } diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index d7564c353..c24965e8d 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) } @@ -160,7 +160,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "", "", false, false, false) + "", "", false, false, false, false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index f5766522a..cabd0aacf 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -39,6 +39,9 @@ const ( noBrowserFlag = "no-browser" noBrowserDesc = "do not open the browser for SSO login" + showQRFlag = "qr" + showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)" + profileNameFlag = "profile" profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used." 
) @@ -48,6 +51,7 @@ var ( dnsLabels []string dnsLabelsValidated domain.List noBrowser bool + showQR bool profileName string configPath string @@ -80,6 +84,7 @@ func init() { ) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) + upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ") @@ -430,6 +435,10 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro req.BlockInbound = &blockInbound } + if cmd.Flag(disableIPv6Flag).Changed { + req.DisableIpv6 = &disableIPv6 + } + if cmd.Flag(enableLazyConnectionFlag).Changed { req.LazyConnectionEnabled = &lazyConnEnabled } @@ -547,6 +556,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.BlockInbound = &blockInbound } + if cmd.Flag(disableIPv6Flag).Changed { + ic.DisableIPv6 = &disableIPv6 + } + if cmd.Flag(enableLazyConnectionFlag).Changed { ic.LazyConnectionEnabled = &lazyConnEnabled } @@ -661,6 +674,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.BlockInbound = &blockInbound } + if cmd.Flag(disableIPv6Flag).Changed { + loginRequest.DisableIpv6 = &disableIPv6 + } + if cmd.Flag(enableLazyConnectionFlag).Changed { loginRequest.LazyConnectionEnabled = &lazyConnEnabled } diff --git a/client/embed/capture.go b/client/embed/capture.go new file mode 100644 index 000000000..30f9b496f --- /dev/null +++ b/client/embed/capture.go @@ -0,0 +1,65 @@ +package embed + +import ( + "io" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/util/capture" +) + +// CaptureOptions configures a packet capture session. +type CaptureOptions struct { + // Output receives pcap-formatted data. Nil disables pcap output. 
+ Output io.Writer + // TextOutput receives human-readable packet summaries. Nil disables text output. + TextOutput io.Writer + // Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443"). + // Empty captures all packets. + Filter string + // Verbose adds seq/ack, TTL, window, and total length to text output. + Verbose bool + // ASCII dumps transport payload as printable ASCII after each packet line. + ASCII bool +} + +// CaptureStats reports capture session counters. +type CaptureStats struct { + Packets int64 + Bytes int64 + Dropped int64 +} + +// CaptureSession represents an active packet capture. Call Stop to end the +// capture and flush buffered packets. +type CaptureSession struct { + sess *capture.Session + engine *internal.Engine +} + +// Stop ends the capture, flushes remaining packets, and detaches from the device. +// Safe to call multiple times. +func (cs *CaptureSession) Stop() { + if cs.engine != nil { + _ = cs.engine.SetCapture(nil) + cs.engine = nil + } + if cs.sess != nil { + cs.sess.Stop() + } +} + +// Stats returns current capture counters. +func (cs *CaptureSession) Stats() CaptureStats { + s := cs.sess.Stats() + return CaptureStats{ + Packets: s.Packets, + Bytes: s.Bytes, + Dropped: s.Dropped, + } +} + +// Done returns a channel that is closed when the capture's writer goroutine +// has fully exited and all buffered packets have been flushed. 
+func (cs *CaptureSession) Done() <-chan struct{} { + return cs.sess.Done() +} diff --git a/client/embed/embed.go b/client/embed/embed.go index 88f7e541c..4b9445b97 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/util/capture" ) var ( @@ -65,7 +66,7 @@ type Options struct { PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string - // PreSharedKey is the pre-shared key for the WireGuard interface + // PreSharedKey is the pre-shared key for the tunnel interface PreSharedKey string // LogOutput is the output destination for logs (defaults to os.Stderr if nil) LogOutput io.Writer @@ -79,11 +80,13 @@ type Options struct { StatePath string // DisableClientRoutes disables the client routes DisableClientRoutes bool + // DisableIPv6 disables IPv6 overlay addressing + DisableIPv6 bool // BlockInbound blocks all inbound connections from peers BlockInbound bool - // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. + // WireguardPort is the port for the tunnel interface. Use 0 for a random port. WireguardPort *int - // MTU is the MTU for the WireGuard interface. + // MTU is the MTU for the tunnel interface. // Valid values are in the range 576..8192 bytes. // If non-nil, this value overrides any value stored in the config file. // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. 
@@ -170,6 +173,7 @@ func New(opts Options) (*Client, error) { PreSharedKey: &opts.PreSharedKey, DisableServerRoutes: &t, DisableClientRoutes: &opts.DisableClientRoutes, + DisableIPv6: &opts.DisableIPv6, BlockInbound: &opts.BlockInbound, WireguardPort: opts.WireguardPort, MTU: opts.MTU, @@ -469,6 +473,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { return sshcommon.VerifyHostKey(storedKey, key, peerAddress) } +// StartCapture begins capturing packets on this client's tunnel device. +// Only one capture can be active at a time; starting a new one stops the previous. +// Call StopCapture (or CaptureSession.Stop) to end it. +func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + var matcher capture.Matcher + if opts.Filter != "" { + m, err := capture.ParseFilter(opts.Filter) + if err != nil { + return nil, fmt.Errorf("parse filter: %w", err) + } + matcher = m + } + + sess, err := capture.NewSession(capture.Options{ + Output: opts.Output, + TextOutput: opts.TextOutput, + Matcher: matcher, + Verbose: opts.Verbose, + ASCII: opts.ASCII, + }) + if err != nil { + return nil, fmt.Errorf("create capture session: %w", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + return nil, fmt.Errorf("set capture: %w", err) + } + + return &CaptureSession{sess: sess, engine: engine}, nil +} + +// StopCapture stops the active capture session if one is running. +func (c *Client) StopCapture() error { + engine, err := c.getEngine() + if err != nil { + return err + } + return engine.SetCapture(nil) +} + // getEngine safely retrieves the engine from the client with proper locking. // Returns ErrClientNotStarted if the client is not started. // Returns ErrEngineNotStarted if the engine is not available. 
diff --git a/client/firewall/firewalld/firewalld.go b/client/firewall/firewalld/firewalld.go new file mode 100644 index 000000000..188ea61dd --- /dev/null +++ b/client/firewall/firewalld/firewalld.go @@ -0,0 +1,11 @@ +// Package firewalld integrates with the firewalld daemon so NetBird can place +// its wg interface into firewalld's "trusted" zone. This is required because +// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent +// versions, which returns EPERM to any other process that tries to insert +// rules into them. The workaround mirrors what Tailscale does: let firewalld +// itself add the accept rules to its own chains by trusting the interface. +package firewalld + +// TrustedZone is the firewalld zone name used for interfaces whose traffic +// should bypass firewalld filtering. +const TrustedZone = "trusted" diff --git a/client/firewall/firewalld/firewalld_linux.go b/client/firewall/firewalld/firewalld_linux.go new file mode 100644 index 000000000..924a04b0a --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux.go @@ -0,0 +1,260 @@ +//go:build linux + +package firewalld + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" + "sync" + "time" + + "github.com/godbus/dbus/v5" + log "github.com/sirupsen/logrus" +) + +const ( + dbusDest = "org.fedoraproject.FirewallD1" + dbusPath = "/org/fedoraproject/FirewallD1" + dbusRootIface = "org.fedoraproject.FirewallD1" + dbusZoneIface = "org.fedoraproject.FirewallD1.zone" + + errZoneAlreadySet = "ZONE_ALREADY_SET" + errAlreadyEnabled = "ALREADY_ENABLED" + errUnknownIface = "UNKNOWN_INTERFACE" + errNotEnabled = "NOT_ENABLED" + + // callTimeout bounds each individual DBus or firewall-cmd invocation. + // A fresh context is created for each call so a slow DBus probe can't + // exhaust the deadline before the firewall-cmd fallback gets to run. 
+ callTimeout = 3 * time.Second +) + +var ( + errDBusUnavailable = errors.New("firewalld dbus unavailable") + + // trustLogOnce ensures the "added to trusted zone" message is logged at + // Info level only for the first successful add per process; repeat adds + // from other init paths are quieter. + trustLogOnce sync.Once + + parentCtxMu sync.RWMutex + parentCtx context.Context = context.Background() +) + +// SetParentContext installs a parent context whose cancellation aborts any +// in-flight TrustInterface call. It does not affect UntrustInterface, which +// always uses a fresh Background-rooted timeout so cleanup can still run +// during engine shutdown when the engine context is already cancelled. +func SetParentContext(ctx context.Context) { + parentCtxMu.Lock() + parentCtx = ctx + parentCtxMu.Unlock() +} + +func getParentContext() context.Context { + parentCtxMu.RLock() + defer parentCtxMu.RUnlock() + return parentCtx +} + +// TrustInterface places iface into firewalld's trusted zone if firewalld is +// running. It is idempotent and best-effort: errors are returned so callers +// can log, but a non-running firewalld is not an error. Only the first +// successful call per process logs at Info. Respects the parent context set +// via SetParentContext so startup-time cancellation unblocks it. +func TrustInterface(iface string) error { + parent := getParentContext() + if !isRunning(parent) { + return nil + } + if err := addTrusted(parent, iface); err != nil { + return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err) + } + trustLogOnce.Do(func() { + log.Infof("added %s to firewalld trusted zone", iface) + }) + log.Debugf("firewalld: ensured %s is in trusted zone", iface) + return nil +} + +// UntrustInterface removes iface from firewalld's trusted zone if firewalld +// is running. Idempotent. Uses a Background-rooted timeout so it still runs +// during shutdown after the engine context has been cancelled. 
+func UntrustInterface(iface string) error { + if !isRunning(context.Background()) { + return nil + } + if err := removeTrusted(context.Background(), iface); err != nil { + return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err) + } + return nil +} + +func newCallContext(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, callTimeout) +} + +func isRunning(parent context.Context) bool { + ctx, cancel := newCallContext(parent) + ok, err := isRunningDBus(ctx) + cancel() + if err == nil { + return ok + } + if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) { + ctx, cancel = newCallContext(parent) + defer cancel() + return isRunningCLI(ctx) + } + return false +} + +func addTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := addDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return addCLI(ctx, iface) +} + +func removeTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := removeDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return removeCLI(ctx, iface) +} + +func isRunningDBus(ctx context.Context) (bool, error) { + conn, err := dbus.SystemBus() + if err != nil { + return false, fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + var zone string + if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil { + return false, fmt.Errorf("firewalld getDefaultZone: %w", err) + } + return true, nil +} + 
+func isRunningCLI(ctx context.Context) bool { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return false + } + return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil +} + +func addDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errAlreadyEnabled) { + return nil + } + + if dbusErrContains(call.Err, errZoneAlreadySet) { + move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface) + if move.Err != nil { + return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err) + } + return nil + } + + return fmt.Errorf("firewalld addInterface: %w", call.Err) +} + +func removeDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) { + return nil + } + + return fmt.Errorf("firewalld removeInterface: %w", call.Err) +} + +func addCLI(ctx context.Context, iface string) error { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + // --change-interface (no --permanent) binds the interface for the + // current runtime only; we do not want membership to persist across + // reboots because netbird re-asserts it on every startup. 
+ out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface, + ).CombinedOutput() + if err != nil { + return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out))) + } + return nil +} + +func removeCLI(ctx context.Context, iface string) error { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface, + ).CombinedOutput() + if err != nil { + msg := strings.TrimSpace(string(out)) + if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) { + return nil + } + return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg) + } + return nil +} + +func dbusErrContains(err error, code string) bool { + if err == nil { + return false + } + var de dbus.Error + if errors.As(err, &de) { + for _, b := range de.Body { + if s, ok := b.(string); ok && strings.Contains(s, code) { + return true + } + } + } + return strings.Contains(err.Error(), code) +} diff --git a/client/firewall/firewalld/firewalld_linux_test.go b/client/firewall/firewalld/firewalld_linux_test.go new file mode 100644 index 000000000..d812745fc --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux_test.go @@ -0,0 +1,49 @@ +//go:build linux + +package firewalld + +import ( + "errors" + "testing" + + "github.com/godbus/dbus/v5" +) + +func TestDBusErrContains(t *testing.T) { + tests := []struct { + name string + err error + code string + want bool + }{ + {"nil error", nil, errZoneAlreadySet, false}, + {"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true}, + {"plain error miss", errors.New("something else"), errZoneAlreadySet, false}, + { + "dbus.Error body match", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}}, + errZoneAlreadySet, + true, + }, + 
{ + "dbus.Error body miss", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}}, + errAlreadyEnabled, + false, + }, + { + "dbus.Error non-string body falls back to Error()", + dbus.Error{Name: "x", Body: []any{123}}, + "x", + true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := dbusErrContains(tc.err, tc.code) + if got != tc.want { + t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want) + } + }) + } +} diff --git a/client/firewall/firewalld/firewalld_other.go b/client/firewall/firewalld/firewalld_other.go new file mode 100644 index 000000000..cfa28221d --- /dev/null +++ b/client/firewall/firewalld/firewalld_other.go @@ -0,0 +1,25 @@ +//go:build !linux + +package firewalld + +import "context" + +// SetParentContext is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func SetParentContext(context.Context) { + // intentionally empty: firewalld is a Linux-only daemon +} + +// TrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func TrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} + +// UntrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. 
+func UntrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index e629f7881..e5e19cec9 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -40,6 +40,7 @@ type aclManager struct { entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + v6 bool stateManager *statemanager.Manager } @@ -51,6 +52,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, }, nil } @@ -85,7 +87,11 @@ func (m *aclManager) AddPeerFiltering( chain := chainNameInputRules ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) - specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) + if m.v6 && ipsetName != "" { + ipsetName += "-v6" + } + proto := protoForFamily(protocol, m.v6) + specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName) mangleSpecs := slices.Clone(specs) mangleSpecs = append(mangleSpecs, @@ -109,6 +115,7 @@ func (m *aclManager) AddPeerFiltering( ip: ip.String(), chain: chain, specs: specs, + v6: m.v6, }}, nil } @@ -161,6 +168,7 @@ func (m *aclManager) AddPeerFiltering( ipsetName: ipsetName, ip: ip.String(), chain: chain, + v6: m.v6, } m.updateState() @@ -413,8 +421,13 @@ func (m *aclManager) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.ACLEntries = m.entries - currentState.ACLIPsetStore = m.ipsetStore + if m.v6 { + currentState.ACLEntries6 = m.entries + currentState.ACLIPsetStore6 = m.ipsetStore + } else { + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + } if err := m.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) 
@@ -422,13 +435,22 @@ func (m *aclManager) updateState() { } // filterRuleSpecs returns the specs of a filtering rule +// protoForFamily translates ICMP to ICMPv6 for ip6tables. +// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp". +func protoForFamily(protocol firewall.Protocol, v6 bool) string { + if v6 && protocol == firewall.ProtocolICMP { + return "ipv6-icmp" + } + return string(protocol) +} + func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { // don't use IP matching if IP is 0.0.0.0 matchByIP := !ip.IsUnspecified() if matchByIP { if ipsetName != "" { - specs = append(specs, "-m", "set", "--set", ipsetName, "src") + specs = append(specs, "-m", "set", "--match-set", ipsetName, "src") } else { specs = append(specs, "-s", ip.String()) } @@ -474,6 +496,9 @@ func (m *aclManager) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if m.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a1d4467d5..696537dd8 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -12,11 +12,16 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" ) +type resetter interface { + Reset() error +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -27,6 +32,11 @@ type Manager struct { aclMgr *aclManager router *router rawSupported bool + + // IPv6 counterparts, nil when no v6 overlay + ipv6Client 
*iptables.IPTables + aclMgr6 *aclManager + router6 *router } // iFaceMapper defines subset methods of interface required for manager @@ -57,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error { + ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return fmt.Errorf("init ip6tables: %w", err) + } + m.ipv6Client = ip6Client + + m.router6, err = newRouter(ip6Client, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclMgr6, err = newAclManager(ip6Client, wgIface) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +func (m *Manager) hasIPv6() bool { + return m.ipv6Client != nil +} + func (m *Manager) Init(stateManager *statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ @@ -73,19 +117,20 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - if err := m.router.init(stateManager); err != nil { - return fmt.Errorf("router init: %w", err) - } - - if err := m.aclMgr.init(stateManager); err != nil { - // TODO: cleanup router - return fmt.Errorf("acl manager init: %w", err) + if err := m.initChains(stateManager); err != nil { + return err } if err := m.initNoTrackChain(); err != nil { log.Warnf("raw table not available, notrack rules will be disabled: %v", err) } + // Trust after all fatal init steps so a later failure doesn't leave the + // interface in 
firewalld's trusted zone without a corresponding Close. + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + // persist early to ensure cleanup of chains go func() { if err := stateManager.PersistState(context.Background()); err != nil { @@ -96,6 +141,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return nil } +// initChains initializes router and ACL chains for both address families, +// rolling back on failure. +func (m *Manager) initChains(stateManager *statemanager.Manager) error { + type initStep struct { + name string + init func(*statemanager.Manager) error + mgr resetter + } + + steps := []initStep{ + {"router", m.router.init, m.router}, + {"acl manager", m.aclMgr.init, m.aclMgr}, + } + if m.hasIPv6() { + steps = append(steps, + initStep{"v6 router", m.router6.init, m.router6}, + initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6}, + ) + } + + var initialized []initStep + for _, s := range steps { + if err := s.init(stateManager); err != nil { + for i := len(initialized) - 1; i >= 0; i-- { + if rerr := initialized[i].mgr.Reset(); rerr != nil { + log.Warnf("rollback %s: %v", initialized[i].name, rerr) + } + } + return fmt.Errorf("%s init: %w", s.name, err) + } + initialized = append(initialized, s) + } + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -111,7 +191,13 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if ip.To4() != nil { + return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + } + if !m.hasIPv6() { + return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized) + } + return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) 
AddRouteFiltering( @@ -125,25 +211,48 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + // DeletePeerRule from the firewall by rule definition func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6IptRule(rule) { + return m.aclMgr6.DeletePeerRule(rule) + } return m.aclMgr.DeletePeerRule(rule) } +func isIPv6IptRule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.v6 +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash. Check v4 first, try v6 if not found. 
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -159,18 +268,65 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes need NAT in both tables since resolved IPs can be + // either v4 or v6. This covers both DomainSet (modern) and the legacy + // wildcard 0.0.0.0/0 destination where the client resolves DNS. + if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + var merr *multierror.Error + + if err := m.router.RemoveNatRule(pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err)) + } + + if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); 
err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Reset firewall to the default state @@ -184,6 +340,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) } + if m.hasIPv6() { + if err := m.aclMgr6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err)) + } + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err)) + } + } + if err := m.aclMgr.Reset(); err != nil { merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } @@ -191,6 +356,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) } + // Appending to merr intentionally blocks DeleteState below so ShutdownState + // stays persisted and the crash-recovery path retries firewalld cleanup. + if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + // attempt to delete state only if all other operations succeeded if merr == nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil { @@ -205,19 +376,21 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { // This is called when USPFilter wraps the native firewall, adding blanket accept // rules so that packet filtering is handled in userspace instead of by netfilter. 
func (m *Manager) AllowNetbird() error { - _, err := m.AddPeerFiltering( - nil, - net.IP{0, 0, 0, 0}, - firewall.ProtocolALL, - nil, - nil, - firewall.ActionAccept, - "", - ) - if err != nil { - return fmt.Errorf("allow netbird interface traffic: %w", err) + var merr *multierror.Error + if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err)) } - return nil + if m.hasIPv6() { + if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err)) + } + } + + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + + return nberrors.FormatErrorOrNil(merr) } // Flush doesn't need to be implemented for this manager @@ -247,6 +420,12 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if rule.TranslatedAddress.Is6() { + if !m.hasIPv6() { + return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -255,6 +434,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -263,39 +445,82 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } 
+ + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. -func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveInboundDNAT removes an inbound DNAT rule. -func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
-func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. -func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) } const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a7c4f67dd..290e5da1e 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -54,8 +54,10 @@ const ( snatSuffix = "_snat" fwdSuffix = "_fwd" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. 
+ ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 ) type ruleInfo struct { @@ -86,6 +88,7 @@ type router struct { wgIface iFaceMapper legacyManagement bool mtu uint16 + v6 bool stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState @@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1 rules: make(map[string][]string), wgIface: wgIface, mtu: mtu, + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() @@ -392,9 +401,13 @@ func (r *router) cleanUpDefaultForwardRules() error { // Remove jump rules from built-in chains before deleting custom chains, // otherwise the chain deletion fails with "device or resource busy". - jumpRule := []string{"-j", chainNATOutput} - if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { - log.Debugf("clean OUTPUT jump rule: %v", err) + if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil { + return fmt.Errorf("check chain %s: %w", chainNATOutput, err) + } else if ok { + jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { + log.Debugf("clean OUTPUT jump rule: %v", err) + } } for _, chainInfo := range []struct { @@ -434,6 +447,12 @@ func (r *router) createContainers() error { {chainRTRDR, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { + // Fallback: clear chains that survived an unclean shutdown. 
+ if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok { + if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { + log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err) + } + } if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } @@ -540,9 +559,12 @@ func (r *router) addPostroutingRules() error { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.v6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead // Add jump rule from FORWARD chain in mangle table to our custom chain jumpRule := []string{ @@ -727,8 +749,13 @@ func (r *router) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.RouteRules = r.rules - currentState.RouteIPsetCounter = r.ipsetCounter + if r.v6 { + currentState.RouteRules6 = r.rules + currentState.RouteIPsetCounter6 = r.ipsetCounter + } else { + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + } if err := r.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -856,7 +883,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { } if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { + if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) } delete(r.rules, ruleKey+fwdSuffix) @@ -883,7 +910,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net rule = append(rule, destExp...) 
if params.Proto != firewall.ProtocolALL { - rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6))) rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--dport", params.DPort)...) } @@ -900,11 +927,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } if network.IsSet() { - if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + name := r.ipsetName(network.Set.HashedName()) + if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil { return nil, fmt.Errorf("create or get ipset: %w", err) } - return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + return []string{"-m", "set", matchSet, name, direction}, nil } if network.IsPrefix() { return []string{flag, network.Prefix.String()}, nil @@ -915,27 +943,23 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + name := r.ipsetName(set.HashedName()) var merr *multierror.Error for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + if err := r.addPrefixToIPSet(name, prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { - log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + log.Debugf("updated set %s with prefixes %v", name, prefixes) } return nberrors.FormatErrorOrNil(merr) } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
-func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil @@ -943,12 +967,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol dnatRule := []string{ "-i", r.wgIface.Name(), - "-p", strings.ToLower(string(protocol)), - "--dport", strconv.Itoa(int(sourcePort)), + "-p", strings.ToLower(protoForFamily(protocol, r.v6)), + "--dport", strconv.Itoa(int(originalPort)), "-d", localAddr.String(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "DNAT", - "--to-destination", ":" + strconv.Itoa(int(targetPort)), + "--to-destination", ":" + strconv.Itoa(int(translatedPort)), } ruleInfo := ruleInfo{ @@ -967,8 +991,8 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol } // RemoveInboundDNAT removes an inbound DNAT rule. -func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if dnatRule, exists := r.rules[ruleID]; exists { if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { @@ -1013,8 +1037,8 @@ func (r *router) ensureNATOutputChain() error { } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
-func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil @@ -1025,11 +1049,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, } dnatRule := []string{ - "-p", strings.ToLower(string(protocol)), - "--dport", strconv.Itoa(int(sourcePort)), + "-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())), + "--dport", strconv.Itoa(int(originalPort)), "-d", localAddr.String(), "-j", "DNAT", - "--to-destination", ":" + strconv.Itoa(int(targetPort)), + "--to-destination", ":" + strconv.Itoa(int(translatedPort)), } if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil { @@ -1042,8 +1066,8 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. 
-func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if dnatRule, exists := r.rules[ruleID]; exists { if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil { @@ -1076,10 +1100,22 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } +// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router +// to avoid collisions since ipsets are global in the kernel. +func (r *router) ipsetName(name string) string { + if r.v6 { + return name + "-v6" + } + return name +} + func (r *router) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if r.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index aa4d2d079..4f4eab167 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -9,6 +9,7 @@ type Rule struct { mangleSpecs []string ip string chain string + v6 bool } // GetRuleID returns the rule id diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 121c755e9..f4be37d01 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,8 @@ import ( "fmt" "sync" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -32,6 +34,12 @@ type 
ShutdownState struct { ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` + + // IPv6 counterparts + RouteRules6 routeRules `json:"route_rules_v6,omitempty"` + RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"` + ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"` + ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"` } func (s *ShutdownState) Name() string { @@ -62,6 +70,28 @@ func (s *ShutdownState) Cleanup() error { ipt.aclMgr.ipsetStore = s.ACLIPsetStore } + // Clean up v6 state even if the current run has no IPv6. + // The previous run may have left ip6tables rules behind. + if !ipt.hasIPv6() { + if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil { + log.Warnf("failed to create v6 components for cleanup: %v", err) + } + } + if ipt.hasIPv6() { + if s.RouteRules6 != nil { + ipt.router6.rules = s.RouteRules6 + } + if s.RouteIPsetCounter6 != nil { + ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6) + } + if s.ACLEntries6 != nil { + ipt.aclMgr6.entries = s.ACLEntries6 + } + if s.ACLIPsetStore6 != nil { + ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6 + } + } + if err := ipt.Close(nil); err != nil { return fmt.Errorf("reset iptables manager: %w", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index d65d717b3..149c6db83 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,6 +1,7 @@ package manager import ( + "errors" "fmt" "net" "net/netip" @@ -11,6 +12,10 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) +// ErrIPv6NotInitialized is returned when an IPv6 address is passed to a firewall +// method but the IPv6 firewall components were not initialized. 
+var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized") + const ( ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormat = "netbird-fwd-%s-%t" @@ -164,18 +169,16 @@ type Manager interface { UpdateSet(hash Set, prefixes []netip.Prefix) error // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services - AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // RemoveInboundDNAT removes inbound DNAT rule - RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. - // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. - AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + AddOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. - // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. - RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. // This prevents conntrack from interfering with WireGuard proxy communication. 
diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index 079c051d9..096f8b9bb 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,6 +1,8 @@ package manager import ( + "net/netip" + "github.com/netbirdio/netbird/route" ) @@ -10,6 +12,10 @@ type RouterPair struct { Destination Network Masquerade bool Inverse bool + // Dynamic indicates the route is domain-based. NAT rules for dynamic + // routes are duplicated to the v6 table so that resolved AAAA records + // are masqueraded correctly. + Dynamic bool } func GetInversePair(pair RouterPair) RouterPair { @@ -20,5 +26,17 @@ func GetInversePair(pair RouterPair) RouterPair { Destination: pair.Source, Masquerade: pair.Masquerade, Inverse: true, + Dynamic: pair.Dynamic, } } + +// ToV6NatPair creates a v6 counterpart of a v4 NAT pair with `::/0` source +// and, for prefix destinations, `::/0` destination. +func ToV6NatPair(pair RouterPair) RouterPair { + v6 := pair + v6.Source = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if v6.Destination.IsPrefix() { + v6.Destination = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + } + return v6 +} diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index a9d066e2f..9d2ea7264 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -33,15 +33,12 @@ const ( const flushError = "flush: %w" -var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} -) - type AclManager struct { rConn *nftables.Conn sConn *nftables.Conn wgIface iFaceMapper routingFwChainName string + af addrFamily workTable *nftables.Table chainInputRules *nftables.Chain @@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam wgIface: wgIface, workTable: table, routingFwChainName: routingFwChainName, + af: familyForAddr(table.Family == nftables.TableFamilyIPv4), 
ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { } if _, ok := ips[r.ip.String()]; ok { - err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) + err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}}) if err != nil { log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) } @@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering( expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), + Offset: m.af.protoOffset, Len: uint32(1), }) - protoData, err := protoToInt(proto) + protoData, err := m.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %v", err) } @@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering( }) } - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) // check if rawIP contains zeroed IPv4 0.0.0.0 value // in that case not add IP match expression into the rule definition - if !bytes.HasPrefix(anyIP, rawIP) { - // source address position - addrOffset := uint32(12) - + if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: addrOffset, - Len: 4, + Offset: m.af.srcAddrOffset, + Len: m.af.addrLen, }, ) // add individual IP for match if no ipset defined @@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) if err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { return nil, fmt.Errorf("get set name: %v", err) @@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) 
(*nftables.Se Name: name, Table: table, Dynamic: true, - KeyType: nftables.TypeIPAddr, + KeyType: m.af.setKeyType, } if err := m.rConn.AddSet(ipset, nil); err != nil { @@ -707,15 +702,12 @@ func ifname(n string) []byte { return b } -func protoToInt(protocol firewall.Protocol) (uint8, error) { - switch protocol { - case firewall.ProtocolTCP: - return unix.IPPROTO_TCP, nil - case firewall.ProtocolUDP: - return unix.IPPROTO_UDP, nil - case firewall.ProtocolICMP: - return unix.IPPROTO_ICMP, nil - } - return 0, fmt.Errorf("unsupported protocol: %s", protocol) +// ipToBytes converts net.IP to the correct byte length for the address family. +func ipToBytes(ip net.IP, af addrFamily) []byte { + if af.addrLen == 4 { + return ip.To4() + } + return ip.To16() } + diff --git a/client/firewall/nftables/addr_family_linux.go b/client/firewall/nftables/addr_family_linux.go new file mode 100644 index 000000000..0c90d704a --- /dev/null +++ b/client/firewall/nftables/addr_family_linux.go @@ -0,0 +1,81 @@ +package nftables + +import ( + "fmt" + "net" + + "github.com/google/nftables" + "golang.org/x/sys/unix" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +var ( + // afIPv4 defines IPv4 header layout and nftables types. + afIPv4 = addrFamily{ + protoOffset: 9, + srcAddrOffset: 12, + dstAddrOffset: 16, + addrLen: net.IPv4len, + totalBits: 8 * net.IPv4len, + setKeyType: nftables.TypeIPAddr, + tableFamily: nftables.TableFamilyIPv4, + icmpProto: unix.IPPROTO_ICMP, + } + // afIPv6 defines IPv6 header layout and nftables types. + afIPv6 = addrFamily{ + protoOffset: 6, + srcAddrOffset: 8, + dstAddrOffset: 24, + addrLen: net.IPv6len, + totalBits: 8 * net.IPv6len, + setKeyType: nftables.TypeIP6Addr, + tableFamily: nftables.TableFamilyIPv6, + icmpProto: unix.IPPROTO_ICMPV6, + } +) + +// addrFamily holds protocol-specific constants for nftables expression building. 
+type addrFamily struct { + // protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6) + protoOffset uint32 + // srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6) + srcAddrOffset uint32 + // dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6) + dstAddrOffset uint32 + // addrLen is the byte length of addresses (4 for v4, 16 for v6) + addrLen uint32 + // totalBits is the address size in bits (32 for v4, 128 for v6) + totalBits int + // setKeyType is the nftables set data type for addresses + setKeyType nftables.SetDatatype + // tableFamily is the nftables table family + tableFamily nftables.TableFamily + // icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6) + icmpProto uint8 +} + +// familyForAddr returns the address family for the given IP. +func familyForAddr(is4 bool) addrFamily { + if is4 { + return afIPv4 + } + return afIPv6 +} + +// protoNum converts a firewall protocol to the IP protocol number, +// using the correct ICMP variant for the address family. 
+func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return af.icmpProto, nil + case firewall.ProtocolALL: + return 0, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} diff --git a/client/firewall/nftables/external_chain_monitor_integration_linux_test.go b/client/firewall/nftables/external_chain_monitor_integration_linux_test.go new file mode 100644 index 000000000..3c4e3f44d --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_integration_linux_test.go @@ -0,0 +1,76 @@ +//go:build linux + +package nftables + +import ( + "os" + "sync/atomic" + "testing" + "time" + + "github.com/google/nftables" + "github.com/stretchr/testify/require" +) + +// TestExternalChainMonitorRootIntegration verifies that adding a new chain +// in an external (non-netbird) filter table triggers the reconciler. +// Requires CAP_NET_ADMIN; skip otherwise. +func TestExternalChainMonitorRootIntegration(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("root required") + } + + calls := make(chan struct{}, 8) + var count atomic.Int32 + rec := &countingReconciler{calls: calls, count: &count} + + m := newExternalChainMonitor(rec) + m.start() + t.Cleanup(m.stop) + + // Give the netlink subscription a moment to register. 
+ time.Sleep(200 * time.Millisecond) + + conn := &nftables.Conn{} + table := conn.AddTable(&nftables.Table{ + Name: "nbmon_integration_test", + Family: nftables.TableFamilyINet, + }) + t.Cleanup(func() { + cleanup := &nftables.Conn{} + cleanup.DelTable(table) + _ = cleanup.Flush() + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "filter_INPUT", + Table: table, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + Type: nftables.ChainTypeFilter, + }) + _ = chain + require.NoError(t, conn.Flush(), "create external test chain") + + select { + case <-calls: + // success + case <-time.After(3 * time.Second): + t.Fatalf("reconcile was not invoked after creating an external chain") + } + require.GreaterOrEqual(t, count.Load(), int32(1)) +} + +type countingReconciler struct { + calls chan struct{} + count *atomic.Int32 +} + +func (c *countingReconciler) reconcileExternalChains() error { + c.count.Add(1) + select { + case c.calls <- struct{}{}: + default: + } + return nil +} diff --git a/client/firewall/nftables/external_chain_monitor_linux.go b/client/firewall/nftables/external_chain_monitor_linux.go new file mode 100644 index 000000000..2a2e04c09 --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_linux.go @@ -0,0 +1,199 @@ +package nftables + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/google/nftables" + log "github.com/sirupsen/logrus" +) + +const ( + externalMonitorReconcileDelay = 500 * time.Millisecond + externalMonitorInitInterval = 5 * time.Second + externalMonitorMaxInterval = 5 * time.Minute + externalMonitorRandomization = 0.5 +) + +// externalChainReconciler re-applies passthrough accept rules to external +// nftables chains. Implementations must be safe to call from the monitor +// goroutine; the Manager locks its mutex internally. 
+type externalChainReconciler interface { + reconcileExternalChains() error +} + +// externalChainMonitor watches nftables netlink events and triggers a +// reconcile when a new table or chain appears (e.g. after +// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff +// reconnect. +type externalChainMonitor struct { + reconciler externalChainReconciler + + mu sync.Mutex + cancel context.CancelFunc + done chan struct{} +} + +func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor { + return &externalChainMonitor{reconciler: r} +} + +func (m *externalChainMonitor) start() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + m.done = make(chan struct{}) + + go m.run(ctx) +} + +func (m *externalChainMonitor) stop() { + m.mu.Lock() + cancel := m.cancel + done := m.done + m.cancel = nil + m.done = nil + m.mu.Unlock() + + if cancel == nil { + return + } + cancel() + <-done +} + +func (m *externalChainMonitor) run(ctx context.Context) { + defer close(m.done) + + bo := &backoff.ExponentialBackOff{ + InitialInterval: externalMonitorInitInterval, + RandomizationFactor: externalMonitorRandomization, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: externalMonitorMaxInterval, + MaxElapsedTime: 0, + Clock: backoff.SystemClock, + } + bo.Reset() + + for ctx.Err() == nil { + err := m.watch(ctx) + if ctx.Err() != nil { + return + } + + delay := bo.NextBackOff() + log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay) + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + } +} + +func (m *externalChainMonitor) watch(ctx context.Context) error { + events, closeMon, err := m.subscribe() + if err != nil { + return err + } + defer closeMon() + + debounce := time.NewTimer(time.Hour) + if !debounce.Stop() { + <-debounce.C + } + defer debounce.Stop() + + pending := false + for { + select { + case 
<-ctx.Done(): + return nil + case <-debounce.C: + pending = false + m.reconcile() + case ev, ok := <-events: + if !ok { + return errors.New("monitor channel closed") + } + if ev.Error != nil { + return fmt.Errorf("monitor event: %w", ev.Error) + } + if !isRelevantMonitorEvent(ev) { + continue + } + resetDebounce(debounce, pending) + pending = true + } + } +} + +func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) { + conn := &nftables.Conn{} + mon := nftables.NewMonitor( + nftables.WithMonitorAction(nftables.MonitorActionNew), + nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables), + ) + events, err := conn.AddMonitor(mon) + if err != nil { + return nil, nil, fmt.Errorf("add netlink monitor: %w", err) + } + return events, func() { _ = mon.Close() }, nil +} + +// resetDebounce reschedules a pending debounce timer without leaking a stale +// fire on its channel. pending must reflect whether the timer is armed. +func resetDebounce(t *time.Timer, pending bool) { + if pending && !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(externalMonitorReconcileDelay) +} + +func (m *externalChainMonitor) reconcile() { + if err := m.reconciler.reconcileExternalChains(); err != nil { + log.Warnf("reconcile external chain rules: %v", err) + } +} + +// isRelevantMonitorEvent returns true for table/chain creation events on +// families we care about. The reconciler filters to actual external filter +// chains. 
+func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool { + switch ev.Type { + case nftables.MonitorEventTypeNewChain: + chain, ok := ev.Data.(*nftables.Chain) + if !ok || chain == nil || chain.Table == nil { + return false + } + return isMonitoredFamily(chain.Table.Family) + case nftables.MonitorEventTypeNewTable: + table, ok := ev.Data.(*nftables.Table) + if !ok || table == nil { + return false + } + return isMonitoredFamily(table.Family) + } + return false +} + +func isMonitoredFamily(family nftables.TableFamily) bool { + switch family { + case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet: + return true + } + return false +} diff --git a/client/firewall/nftables/external_chain_monitor_linux_test.go b/client/firewall/nftables/external_chain_monitor_linux_test.go new file mode 100644 index 000000000..1a37faca2 --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_linux_test.go @@ -0,0 +1,137 @@ +package nftables + +import ( + "testing" + + "github.com/google/nftables" + "github.com/stretchr/testify/assert" +) + +func TestIsMonitoredFamily(t *testing.T) { + tests := []struct { + family nftables.TableFamily + want bool + }{ + {nftables.TableFamilyIPv4, true}, + {nftables.TableFamilyIPv6, true}, + {nftables.TableFamilyINet, true}, + {nftables.TableFamilyARP, false}, + {nftables.TableFamilyBridge, false}, + {nftables.TableFamilyNetdev, false}, + {nftables.TableFamilyUnspecified, false}, + } + for _, tc := range tests { + assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family) + } +} + +func TestIsRelevantMonitorEvent(t *testing.T) { + inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet} + ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} + arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP} + + tests := []struct { + name string + ev *nftables.MonitorEvent + want bool + }{ + { + name: "new chain in inet firewalld", + 
ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable}, + }, + want: true, + }, + { + name: "new chain in ip filter", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "INPUT", Table: ipTable}, + }, + want: true, + }, + { + name: "new chain in unwatched arp family", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "x", Table: arpTable}, + }, + want: false, + }, + { + name: "new table inet", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewTable, + Data: inetTable, + }, + want: true, + }, + { + name: "del chain (we only act on new)", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeDelChain, + Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable}, + }, + want: false, + }, + { + name: "chain with nil table", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "x"}, + }, + want: false, + }, + { + name: "nil data", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: (*nftables.Chain)(nil), + }, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev)) + }) + } +} + +// fakeReconciler records reconcile invocations for debounce tests. +type fakeReconciler struct { + calls chan struct{} +} + +func (f *fakeReconciler) reconcileExternalChains() error { + f.calls <- struct{}{} + return nil +} + +func TestExternalChainMonitorStopWithoutStart(t *testing.T) { + m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)}) + // Must not panic or block. + m.stop() +} + +func TestExternalChainMonitorDoubleStart(t *testing.T) { + // start() twice should be a no-op; stop() cleans up once. + // We avoid exercising the netlink watch loop here because it needs root. 
+ m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)}) + + // Replace run with a stub that just waits for cancel, so start() stays + // deterministic without opening a netlink socket. + origDone := make(chan struct{}) + m.done = origDone + m.cancel = func() { close(origDone) } + + // Second start should be a no-op (cancel already set). + m.start() + assert.NotNil(t, m.cancel) + + m.stop() + assert.Nil(t, m.cancel) + assert.Nil(t, m.done) +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 0b5b61e04..fdc7c2f3c 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -11,9 +11,12 @@ import ( "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -48,10 +51,17 @@ type Manager struct { rConn *nftables.Conn wgIface iFaceMapper - router *router - aclManager *AclManager + router *router + aclManager *AclManager + + // IPv6 counterparts, nil when no v6 overlay + router6 *router + aclManager6 *AclManager + notrackOutputChain *nftables.Chain notrackPreroutingChain *nftables.Chain + + extMonitor *externalChainMonitor } // Create nftables firewall manager @@ -61,7 +71,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} + tableName := getTableName() + workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ 
-74,35 +85,137 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + + m.extMonitor = newExternalChainMonitor(m) + return m, nil } +func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error { + workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6} + + var err error + m.router6, err = newRouter(workTable6, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +// hasIPv6 reports whether the manager has IPv6 components initialized. +func (m *Manager) hasIPv6() bool { + return m.router6 != nil +} + +func (m *Manager) initIPv6() error { + workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6) + if err != nil { + return fmt.Errorf("create v6 work table: %w", err) + } + + if err := m.router6.init(workTable6); err != nil { + return fmt.Errorf("v6 router init: %w", err) + } + + if err := m.aclManager6.init(workTable6); err != nil { + return fmt.Errorf("v6 acl manager init: %w", err) + } + + return nil +} + // Init nftables firewall manager func (m *Manager) Init(stateManager *statemanager.Manager) error { + if err := m.initFirewall(); err != nil { + return err + } + + m.persistState(stateManager) + + // Start after initFirewall has installed the baseline external-chain + // accept rules. start() is idempotent across Init/Close/Init cycles. 
+ m.extMonitor.start() + + return nil +} + +// reconcileExternalChains re-applies passthrough accept rules to external +// filter chains for both IPv4 and IPv6 routers. Called by the monitor when +// tables or chains appear (e.g. after firewalld reloads). +func (m *Manager) reconcileExternalChains() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + var merr *multierror.Error + if m.router != nil { + if err := m.router.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("v4: %w", err)) + } + } + if m.hasIPv6() { + if err := m.router6.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("v6: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) +} + +func (m *Manager) initFirewall() (err error) { workTable, err := m.createWorkTable() if err != nil { return fmt.Errorf("create work table: %w", err) } + defer func() { + if err != nil { + m.rollbackInit() + } + }() + if err := m.router.init(workTable); err != nil { return fmt.Errorf("router init: %w", err) } if err := m.aclManager.init(workTable); err != nil { - // TODO: cleanup router return fmt.Errorf("acl manager init: %w", err) } + if m.hasIPv6() { + if err := m.initIPv6(); err != nil { + // Peer has a v6 address: v6 firewall MUST work or we risk fail-open. + return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err) + } + } + if err := m.initNoTrackChains(workTable); err != nil { log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } + return nil +} + +// persistState saves the current interface state for potential recreation on restart. +// Unlike iptables, which requires tracking individual rules, nftables maintains +// a known state (our netbird table plus a few static rules). This allows for easy +// cleanup using Close() without needing to store specific rules. 
+func (m *Manager) persistState(stateManager *statemanager.Manager) { stateManager.RegisterState(&ShutdownState{}) - // We only need to record minimal interface state for potential recreation. - // Unlike iptables, which requires tracking individual rules, nftables maintains - // a known state (our netbird table plus a few static rules). This allows for easy - // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ NameStr: m.wgIface.Name(), @@ -113,14 +226,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - // persist early go func() { if err := stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } }() +} - return nil +// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through. +func (m *Manager) rollbackInit() { + if err := m.router.Reset(); err != nil { + log.Warnf("rollback router: %v", err) + } + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + log.Warnf("rollback v6 router: %v", err) + } + } + if err := m.cleanupNetbirdTables(); err != nil { + log.Warnf("cleanup tables: %v", err) + } + if err := m.rConn.Flush(); err != nil { + log.Warnf("flush: %v", err) + } } // AddPeerFiltering rule to the firewall @@ -139,12 +267,14 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - rawIP := ip.To4() - if rawIP == nil { - return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) + if ip.To4() != nil { + return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } - return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if !m.hasIPv6() { + return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized) + } + return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, 
action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -158,8 +288,11 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -170,15 +303,66 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6Rule(rule) { + return m.aclManager6.DeletePeerRule(rule) + } return m.aclManager.DeletePeerRule(rule) } -// DeleteRouteRule deletes a routing rule +func isIPv6Rule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6 +} + +// isIPv6RouteRule determines whether a route rule belongs to the v6 table. +// For static routes, the destination prefix determines the family. For dynamic +// routes (DomainSet), the sources determine the family since management +// duplicates dynamic rules per family. +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + +// DeleteRouteRule deletes a routing rule. Route rules live in exactly one +// router; the cached maps are normally authoritative, so the kernel is only +// consulted when neither map knows about the rule. 
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.DeleteRouteRule(rule) + id := rule.ID() + r, err := m.routerForRuleID(id, (*router).hasRule) + if err != nil { + return err + } + return r.DeleteRouteRule(rule) +} + +// routerForRuleID picks the router holding the rule with the given id, using +// the supplied lookup. If the cached maps disagree (or both miss), it refreshes +// from the kernel once and re-checks before falling back to the v4 router. +func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) { + if has(m.router, id) { + return m.router, nil + } + if m.hasIPv6() && has(m.router6, id) { + return m.router6, nil + } + if !m.hasIPv6() { + return m.router, nil + } + if err := m.router.refreshRulesMap(); err != nil { + return nil, fmt.Errorf("refresh v4 rules: %w", err) + } + if err := m.router6.refreshRulesMap(); err != nil { + return nil, fmt.Errorf("refresh v6 rules: %w", err) + } + if has(m.router6, id) && !has(m.router, id) { + return m.router6, nil + } + return m.router, nil } func (m *Manager) IsServerRouteSupported() bool { @@ -193,19 +377,70 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes need NAT in both tables since resolved IPs can be + // either v4 or v6. This covers both DomainSet (modern) and the legacy + // wildcard 0.0.0.0/0 destination where the client resolves DNS. 
+ // On v6 failure we keep the v4 NAT rule rather than rolling back: half + // connectivity is better than none, and RemoveNatRule is content-keyed + // so the eventual cleanup still works. + if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + var merr *multierror.Error + + if err := m.router.RemoveNatRule(pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err)) + } + + if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic. // This is called when USPFilter wraps the native firewall, adding blanket accept // rules so that packet filtering is handled in userspace instead of by netfilter. +// +// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains, +// which doesn't override DROP rules in external tables (e.g. firewalld). +// Should add passthrough rules to external chains (like the native mode router's +// addExternalChainsRules does) for both the netbird table family and inet tables. +// The netbird table itself is fine (routing chains already exist there), but +// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic. 
func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() @@ -213,40 +448,65 @@ func (m *Manager) AllowNetbird() error { if err := m.aclManager.createDefaultAllowRules(); err != nil { return fmt.Errorf("create default allow rules: %w", err) } + if m.hasIPv6() { + if err := m.aclManager6.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create v6 default allow rules: %w", err) + } + } if err := m.rConn.Flush(); err != nil { return fmt.Errorf("flush allow input netbird rules: %w", err) } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + return nil } // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Close closes the firewall manager func (m *Manager) Close(stateManager *statemanager.Manager) error { + m.extMonitor.stop() + m.mutex.Lock() defer m.mutex.Unlock() + var merr *multierror.Error + if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset router: %v", err) + merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err)) + } + + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err)) + } } if err := m.cleanupNetbirdTables(); err != nil { - return fmt.Errorf("cleanup netbird tables: %v", err) + merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err)) } if err := m.rConn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - return fmt.Errorf("delete state: %v", err) + 
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err)) } - return nil + return nberrors.FormatErrorOrNil(merr) } func (m *Manager) cleanupNetbirdTables() error { @@ -295,6 +555,12 @@ func (m *Manager) Flush() error { return err } + if m.hasIPv6() { + if err := m.aclManager6.Flush(); err != nil { + return fmt.Errorf("flush v6 acl: %w", err) + } + } + if err := m.refreshNoTrackChains(); err != nil { log.Errorf("failed to refresh notrack chains: %v", err) } @@ -307,6 +573,12 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if rule.TranslatedAddress.Is6() { + if !m.hasIPv6() { + return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -315,7 +587,11 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.DeleteDNATRule(rule) + r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule) + if err != nil { + return err + } + return r.DeleteDNATRule(rule) } // UpdateSet updates the set with the given prefixes @@ -323,39 +599,82 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
-func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveInboundDNAT removes an inbound DNAT rule. -func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
-func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. -func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) } const ( @@ -529,7 +848,11 @@ func (m *Manager) refreshNoTrackChains() error { } func (m *Manager) createWorkTable() (*nftables.Table, error) { - tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) + return m.createWorkTableFamily(nftables.TableFamilyIPv4) +} + +func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) { + tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } @@ -541,7 +864,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } } - table := 
m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family}) err = m.rConn.Flush() return table, err } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index d48e4ba88..be4f65881 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -383,10 +383,138 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { err = manager.AddNatRule(pair) require.NoError(t, err, "failed to add NAT rule") + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("100.96.0.2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "failed to add DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule") + }) + stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } +func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} { + if _, err := exec.LookPath(bin); err != nil { + t.Skipf("%s not available on this system: %v", bin, err) + } + } + + // Seed ip6 tables in the nft backend. Docker may not create them. 
+ seedIp6tables(t) + + ifaceMockV6 := &iFaceMock{ + NameFunc: func() string { return "wt-test" }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.96.0.1"), + Network: netip.MustParsePrefix("100.96.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + + manager, err := Create(ifaceMockV6, iface.DefaultMTU) + require.NoError(t, err, "create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + require.NoError(t, manager.Close(nil), "close manager") + + stdout, stderr := runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) + }) + + ip := netip.MustParseAddr("fd00::2") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + require.NoError(t, err, "add v6 peer filtering rule") + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00:1::/64")}, + fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "add v6 route filtering rule") + + err = manager.AddNatRule(fw.RouterPair{ + Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + Masquerade: true, + }) + require.NoError(t, err, "add v6 NAT rule") + + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("fd00::2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "add v6 DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule") + }) + + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + stdout, stderr = runIp6tablesSave(t) + verifyIp6tablesOutput(t, 
stdout, stderr) +} + +func seedIp6tables(t *testing.T) { + t.Helper() + for _, tc := range []struct{ table, chain string }{ + {"filter", "FORWARD"}, + {"nat", "POSTROUTING"}, + {"mangle", "FORWARD"}, + } { + add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT") + require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table) + del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT") + require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table) + } +} + +func runIp6tablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("ip6tables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + require.NoError(t, cmd.Run(), "ip6tables-save failed") + return stdout.String(), stderr.String() +} + +func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + for _, msg := range []string{ + "Table `nat' is incompatible", + "Table `mangle' is incompatible", + "Table `filter' is incompatible", + } { + require.NotContains(t, stdout, msg, + "ip6tables-save stdout reports incompatibility: %s", stdout) + require.NotContains(t, stderr, msg, + "ip6tables-save stderr reports incompatibility: %s", stderr) + } +} + func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this system") diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 904daf7cb..4214455a9 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -19,6 +19,7 @@ import ( "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" @@ -40,6 +41,8 @@ 
const ( chainNameForward = "FORWARD" chainNameMangleForward = "netbird-mangle-forward" + firewalldTableName = "firewalld" + userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptInputRule = "inputaccept" @@ -47,8 +50,10 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. + ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 // maxPrefixesSet 1638 prefixes start to fail, taking some margin maxPrefixesSet = 1500 @@ -73,6 +78,7 @@ type router struct { rules map[string]*nftables.Rule ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] + af addrFamily wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool @@ -85,6 +91,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou workTable: workTable, chains: make(map[string]*nftables.Chain), rules: make(map[string]*nftables.Rule), + af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), mtu: mtu, @@ -133,6 +140,10 @@ func (r *router) Reset() error { merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } + if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + if err := r.removeNatPreroutingRules(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) } @@ -143,7 +154,7 @@ func (r *router) Reset() error { func (r *router) removeNatPreroutingRules() error { table := &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, } chain := &nftables.Chain{ Name: 
chainNameNatPrerouting, @@ -176,7 +187,7 @@ func (r *router) removeNatPreroutingRules() error { } func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily) if err != nil { return nil, fmt.Errorf("list tables: %w", err) } @@ -280,6 +291,10 @@ func (r *router) createContainers() error { log.Errorf("failed to add accept rules for the forward chain: %s", err) } + if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to refresh rules: %s", err) } @@ -408,7 +423,7 @@ func (r *router) AddRouteFiltering( // Handle protocol if proto != firewall.ProtocolALL { - protoNum, err := protoToInt(proto) + protoNum, err := r.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -468,7 +483,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo return nil, fmt.Errorf("create or get ipset: %w", err) } - return getIpSetExprs(ref, isSource) + return r.getIpSetExprs(ref, isSource) +} + +func (r *router) iptablesProto() iptables.Protocol { + if r.af.tableFamily == nftables.TableFamilyIPv6 { + return iptables.ProtocolIPv6 + } + return iptables.ProtocolIPv4 +} + +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + +func (r *router) hasDNATRule(id string) bool { + _, ok := r.rules[id+dnatSuffix] + return ok } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -517,10 +549,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err Table: r.workTable, // required for prefixes Interval: true, - KeyType: nftables.TypeIPAddr, + KeyType: r.af.setKeyType, } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) nElements := len(elements) 
maxElements := maxPrefixesSet * 2 @@ -553,23 +585,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err return nfset, nil } -func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { +func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - // nftables needs half-open intervals [firstIP, lastIP) for prefixes // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc firstIP := prefix.Addr() lastIP := calculateLastIP(prefix).Next() elements = append(elements, - // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 - // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + // the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true}, nftables.SetElement{Key: firstIP.AsSlice()}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) @@ -579,10 +605,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { // calculateLastIP determines the last IP in a given prefix. 
func calculateLastIP(prefix netip.Prefix) netip.Addr { - hostMask := ^uint32(0) >> prefix.Masked().Bits() - lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + masked := prefix.Masked() + if masked.Addr().Is4() { + hostMask := ^uint32(0) >> masked.Bits() + lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask + return netip.AddrFrom4(uint32ToBytes(lastIP)) + } - return netip.AddrFrom4(uint32ToBytes(lastIP)) + // IPv6: set host bits to all 1s + b := masked.Addr().As16() + bits := masked.Bits() + for i := bits; i < 128; i++ { + b[i/8] |= 1 << (7 - i%8) + } + return netip.AddrFrom16(b) } // Utility function to convert netip.Addr to uint32. @@ -834,9 +870,16 @@ func (r *router) addPostroutingRules() { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.af.tableFamily == nftables.TableFamilyIPv6 { + overhead = ipv6TCPHeaderSize + } + if r.mtu <= overhead { + log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead) + return nil + } + mss := r.mtu - overhead exprsOut := []expr.Any{ &expr.Meta{ @@ -1043,17 +1086,22 @@ func (r *router) acceptFilterTableRules() error { log.Debugf("Used %s to add accept forward and input rules", fw) }() - // Try iptables first and fallback to nftables if iptables is not available - ipt, err := iptables.New() + // Try iptables first and fallback to nftables if iptables is not available. + // Use the correct protocol (iptables vs ip6tables) for the address family. 
+ ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { - // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" return r.acceptFilterRulesNftables(r.filterTable) } - return r.acceptFilterRulesIptables(ipt) + if err := r.acceptFilterRulesIptables(ipt); err != nil { + log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err) + fw = "nftables" + return r.acceptFilterRulesNftables(r.filterTable) + } + return nil } func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1124,83 +1172,122 @@ func (r *router) acceptExternalChainsRules() error { } intf := ifname(r.wgIface.Name()) - for _, chain := range chains { - if chain.Hooknum == nil { - log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) - continue - } - - log.Debugf("adding accept rules to external %s chain: %s %s/%s", - hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) - - switch *chain.Hooknum { - case *nftables.ChainHookForward: - r.insertForwardAcceptRules(chain, intf) - case *nftables.ChainHookInput: - r.insertInputAcceptRule(chain, intf) - } + r.applyExternalChainAccept(chain, intf) } if err := r.conn.Flush(); err != nil { return fmt.Errorf("flush external chain rules: %w", err) } - return nil } +func (r *router) applyExternalChainAccept(chain *nftables.Chain, intf []byte) { + if chain.Hooknum == nil { + log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) + return + } + + log.Debugf("adding accept rules to external %s chain: %s %s/%s", + hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) + + switch *chain.Hooknum { + case *nftables.ChainHookForward: + r.insertForwardAcceptRules(chain, intf) + case *nftables.ChainHookInput: + r.insertInputAcceptRule(chain, intf) + } +} + 
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { - iifRule := &nftables.Rule{ + existing, err := r.existingNetbirdRulesInChain(chain) + if err != nil { + log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err) + return + } + r.insertForwardIifRule(chain, intf, existing) + r.insertForwardOifEstablishedRule(chain, intf, existing) +} + +func (r *router) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) { + if existing[userDataAcceptForwardRuleIif] { + return + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictAccept}, }, UserData: []byte(userDataAcceptForwardRuleIif), - } - r.conn.InsertRule(iifRule) + }) +} - oifExprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, +func (r *router) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) { + if existing[userDataAcceptForwardRuleOif] { + return } - oifRule := &nftables.Rule{ + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, - Exprs: append(oifExprs, getEstablishedExprs(2)...), + Exprs: append(exprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), - } - r.conn.InsertRule(oifRule) + }) } func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { - inputRule := &nftables.Rule{ + existing, err := r.existingNetbirdRulesInChain(chain) + if err != nil { + log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err) + return + } + 
if existing[userDataAcceptInputRule] { + return + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictAccept}, }, UserData: []byte(userDataAcceptInputRule), + }) +} + +// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive. +func (r *router) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return nil, fmt.Errorf("list rules: %w", err) } - r.conn.InsertRule(inputRule) + present := map[string]bool{} + for _, rule := range rules { + if !isNetbirdAcceptRuleTag(rule.UserData) { + continue + } + present[string(rule.UserData)] = true + } + return present, nil +} + +func isNetbirdAcceptRuleTag(userData []byte) bool { + switch string(userData) { + case userDataAcceptForwardRuleIif, + userDataAcceptForwardRuleOif, + userDataAcceptInputRule: + return true + } + return false } func (r *router) removeAcceptFilterRules() error { @@ -1222,13 +1309,17 @@ func (r *router) removeFilterTableRules() error { return nil } - ipt, err := iptables.New() + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) return r.removeAcceptRulesFromTable(r.filterTable) } - return r.removeAcceptFilterRulesIptables(ipt) + if err := r.removeAcceptFilterRulesIptables(ipt); err != nil { + log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) + } + return nil } func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { @@ -1295,7 +1386,7 
@@ func (r *router) removeExternalChainsRules() error { func (r *router) findExternalChains() []*nftables.Chain { var chains []*nftables.Chain - families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet} for _, family := range families { allChains, err := r.conn.ListChainsOfTableFamily(family) @@ -1319,8 +1410,15 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } - // Skip all iptables-managed tables in the ip family - if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + // Skip firewalld-owned chains. Firewalld creates its chains with the + // NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM. + // We delegate acceptance to firewalld by trusting the interface instead. + if chain.Table.Name == firewalldTableName { + return false + } + + // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat) + if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) { return false } @@ -1461,7 +1559,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { return rule, nil } - protoNum, err := protoToInt(rule.Protocol) + protoNum, err := r.af.protoNum(rule.Protocol) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -1524,7 +1622,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule dnatExprs = append(dnatExprs, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: regProtoMin, RegProtoMax: regProtoMax, @@ -1617,14 +1715,15 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f }, ) + natTable := &nftables.Table{ + Name: tableNat, + Family: r.af.tableFamily, + } 
dnatRule := &nftables.Rule{ - Table: &nftables.Table{ - Name: tableNat, - Family: nftables.TableFamilyIPv4, - }, + Table: natTable, Chain: &nftables.Chain{ Name: chainNameNatPrerouting, - Table: r.filterTable, + Table: natTable, Type: nftables.ChainTypeNAT, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityNATDest, @@ -1655,8 +1754,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, + Offset: r.af.dstAddrOffset, + Len: r.af.addrLen, }, &expr.Cmp{ Op: expr.CmpOpEq, @@ -1734,7 +1833,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return fmt.Errorf("get set %s: %w", set.HashedName(), err) } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) if err := r.conn.SetAddElements(nfset, elements); err != nil { return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) } @@ -1749,14 +1848,14 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
-func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1783,11 +1882,15 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol &expr.Cmp{ Op: expr.CmpOpEq, Register: 3, - Data: binaryutil.BigEndian.PutUint16(sourcePort), + Data: binaryutil.BigEndian.PutUint16(originalPort), }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1796,11 +1899,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, &expr.Immediate{ Register: 2, - Data: binaryutil.BigEndian.PutUint16(targetPort), + Data: binaryutil.BigEndian.PutUint16(translatedPort), }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, RegProtoMax: 0, @@ -1825,12 +1928,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol } // RemoveInboundDNAT removes an inbound DNAT rule. 
-func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) rule, exists := r.rules[ruleID] if !exists { @@ -1876,8 +1979,8 @@ func (r *router) ensureNATOutputChain() error { } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. -func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil @@ -1887,7 +1990,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, return err } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1908,11 +2011,15 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, &expr.Cmp{ Op: expr.CmpOpEq, Register: 2, - Data: binaryutil.BigEndian.PutUint16(sourcePort), + Data: binaryutil.BigEndian.PutUint16(originalPort), }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) 
+ bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1921,11 +2028,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, &expr.Immediate{ Register: 2, - Data: binaryutil.BigEndian.PutUint16(targetPort), + Data: binaryutil.BigEndian.PutUint16(translatedPort), }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, }, @@ -1949,12 +2056,12 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. -func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) rule, exists := r.rules[ruleID] if !exists { @@ -1993,45 +2100,44 @@ func (r *router) applyNetwork( } if network.IsPrefix() { - return applyPrefix(network.Prefix, isSource), nil + return r.applyPrefix(network.Prefix, isSource), nil } return nil, nil } // applyPrefix generates nftables expressions for a CIDR prefix -func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { - // dst offset - offset := uint32(16) +func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } ones := prefix.Bits() - // 0.0.0.0/0 doesn't need extra 
expressions + // unspecified address (/0) doesn't need extra expressions if ones == 0 { return nil } - mask := net.CIDRMask(ones, 32) + mask := net.CIDRMask(ones, r.af.totalBits) + xor := make([]byte, r.af.addrLen) return []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, - // netmask &expr.Bitwise{ DestRegister: 1, SourceRegister: 1, - Len: 4, + Len: r.af.addrLen, Mask: mask, - Xor: []byte{0, 0, 0, 0}, + Xor: xor, }, - // net address &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, @@ -2114,13 +2220,12 @@ func getCtNewExprs() []expr.Any { } } -func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { - - // dst offset - offset := uint32(16) +func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } return []expr.Any{ @@ -2128,7 +2233,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, &expr.Lookup{ SourceRegister: 1, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index f0e34d211..c5d6729d9 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) - destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) + testRouter := &router{af: afIPv4} + sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // 
nolint:gocritic @@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) { } } +func TestNftablesCreateIpSet_IPv6(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTableIPv6() + require.NoError(t, err, "Failed to create v6 work table") + defer deleteWorkTableIPv6() + + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) + require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IPv6", + sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")}, + }, + { + name: "Multiple IPv6 Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::/48"), + netip.MustParsePrefix("fe80::/10"), + }, + }, + { + name: "Overlapping IPv6", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("fd00::1/128"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + }, + }, + { + name: "Mixed prefix lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::1/128"), + netip.MustParsePrefix("fd00:abcd::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) + require.NoError(t, err, "Failed to create IPv6 set") + require.NotNil(t, set) + + assert.Equal(t, setName, set.Name) + assert.True(t, set.Interval) + assert.Equal(t, nftables.TypeIP6Addr, set.KeyType) + + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + + elements, err := 
r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd && len(elem.Key) == 16 { + ip := netip.AddrFrom16([16]byte(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch") + + r.conn.DelSet(set) + require.NoError(t, r.conn.Flush()) + }) + } +} + +func createWorkTableIPv6() (*nftables.Table, error) { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return nil, err + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + } + } + + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6}) + err = sConn.Flush() + return table, err +} + +func deleteWorkTableIPv6() { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + _ = sConn.Flush() + } + } +} + func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { t.Helper() @@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { var metaFound, cmpFound bool - expectedProto, _ := protoToInt(proto) + expectedProto, _ := afIPv4.protoNum(proto) for _, e := range exprs { switch ex := e.(type) { case *expr.Meta: @@ -854,3 
+985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) { } assert.Equal(t, 1, found, "NAT rule should exist in kernel") } + +func TestCalculateLastIP(t *testing.T) { + tests := []struct { + prefix string + want string + }{ + {"10.0.0.0/24", "10.0.0.255"}, + {"10.0.0.0/32", "10.0.0.0"}, + {"0.0.0.0/0", "255.255.255.255"}, + {"192.168.1.0/28", "192.168.1.15"}, + {"fd00::/64", "fd00::ffff:ffff:ffff:ffff"}, + {"fd00::/128", "fd00::"}, + {"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"}, + {"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}, + } + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + prefix := netip.MustParsePrefix(tt.prefix) + got := calculateLastIP(prefix) + assert.Equal(t, tt.want, got.String()) + }) + } +} + +func TestConvertPrefixesToSet_IPv6(t *testing.T) { + r := &router{af: afIPv6} + prefixes := []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::1/128"), + } + + elements := r.convertPrefixesToSet(prefixes) + + // Each prefix produces 2 elements (start + end) + require.Len(t, elements, 4) + + // fd00::/64 start + assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key)) + assert.False(t, elements[0].IntervalEnd) + + // fd00::/64 end (fd00:0:0:1::, one past the last) + assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key)) + assert.True(t, elements[1].IntervalEnd) + + // 2001:db8::1/128 start + assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key)) + assert.False(t, elements[2].IntervalEnd) + + // 2001:db8::1/128 end (2001:db8::2) + assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key)) + assert.True(t, elements[3].IntervalEnd) +} diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 6a6533344..b120cdf12 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ 
-3,6 +3,9 @@ package uspfilter import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/firewalld" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { if m.nativeFirewall != nil { return m.nativeFirewall.Close(stateManager) } + if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to untrust interface in firewalld: %v", err) + } return nil } @@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error { if m.nativeFirewall != nil { return m.nativeFirewall.AllowNetbird() } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } return nil } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 6aef2ecfd..10a2b9116 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -5,8 +5,10 @@ import ( "os/exec" "syscall" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error { return nil } - if !isFirewallRuleActive(firewallRuleName) { - return nil + var merr *multierror.Error + if isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err)) + } } - if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { - return fmt.Errorf("couldn't remove windows firewall: %w", err) + if isFirewallRuleActive(firewallRuleName + "-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil { + merr = multierror.Append(merr, 
fmt.Errorf("remove windows v6 firewall rule: %w", err)) + } } - return nil + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic @@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error { return nil } - if isFirewallRuleActive(firewallRuleName) { - return nil + if !isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+m.wgIface.Address().IP.String(), + ); err != nil { + return err + } } - return manageFirewallRule(firewallRuleName, - addRule, - "dir=in", - "enable=yes", - "action=allow", - "profile=any", - "localip="+m.wgIface.Address().IP.String(), - ) + + if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+v6.String(), + ); err != nil { + return err + } + } + + return nil } func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go index 7296953db..9c06eb3f7 100644 --- a/client/firewall/uspfilter/common/iface.go +++ b/client/firewall/uspfilter/common/iface.go @@ -9,6 +9,7 @@ import ( // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { + Name() string SetFilter(device.PacketFilter) error Address() wgaddr.Address GetWGDevice() *wgdevice.Device diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 5ddfedccf..e497a0bff 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -1,7 +1,7 @@ package conntrack import ( - "fmt" + "net" "net/netip" "os" "strconv" @@ -111,5 +111,7 @@ type ConnKey struct { } func (c ConnKey) String() string { - 
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) + return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) + + " → " + + net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort))) } diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index d868dd1fb..7e67b98fa 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -13,6 +13,54 @@ import ( var logger = log.NewFromLogrus(logrus.StandardLogger()) var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() +func TestConnKey_String(t *testing.T) { + tests := []struct { + name string + key ConnKey + expect string + }{ + { + name: "IPv4", + key: ConnKey{ + SrcIP: netip.MustParseAddr("192.168.1.1"), + DstIP: netip.MustParseAddr("10.0.0.1"), + SrcPort: 12345, + DstPort: 80, + }, + expect: "192.168.1.1:12345 → 10.0.0.1:80", + }, + { + name: "IPv6", + key: ConnKey{ + SrcIP: netip.MustParseAddr("2001:db8::1"), + DstIP: netip.MustParseAddr("2001:db8::2"), + SrcPort: 54321, + DstPort: 443, + }, + expect: "[2001:db8::1]:54321 → [2001:db8::2]:443", + }, + { + name: "IPv4-mapped IPv6 unmaps", + key: ConnKey{ + SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"), + DstIP: netip.MustParseAddr("::ffff:10.0.0.2"), + SrcPort: 1000, + DstPort: 2000, + }, + expect: "10.0.0.1:1000 → 10.0.0.2:2000", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.key.String() + if got != tc.expect { + t.Errorf("got %q, want %q", got, tc.expect) + } + }) + } +} + // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) { diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 2fd37145a..3c96548b5 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ 
b/client/firewall/uspfilter/conntrack/icmp.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "time" @@ -21,9 +22,14 @@ const ( // ICMPCleanupInterval is how often we check for stale ICMP connections ICMPCleanupInterval = 15 * time.Second - // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info, - // which includes the IP header (20 bytes) and transport header (8 bytes) - MaxICMPPayloadLength = 28 + // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info. + // IPv4: 20-byte header + 8-byte transport = 28 bytes. + // IPv6: 40-byte header + 8-byte transport = 48 bytes. + MaxICMPPayloadLength = 48 + // minICMPPayloadIPv4 is the minimum embedded packet length for IPv4 ICMP errors. + minICMPPayloadIPv4 = 28 + // minICMPPayloadIPv6 is the minimum embedded packet length for IPv6 ICMP errors. + minICMPPayloadIPv6 = 48 ) // ICMPConnKey uniquely identifies an ICMP connection @@ -69,7 +75,7 @@ type ICMPInfo struct { // String implements fmt.Stringer for lazy evaluation in log messages func (info ICMPInfo) String() string { - if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength { + if info.isErrorMessage() && info.PayloadLen >= minICMPPayloadIPv4 { if origInfo := info.parseOriginalPacket(); origInfo != "" { return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo) } @@ -78,42 +84,72 @@ func (info ICMPInfo) String() string { return info.TypeCode.String() } -// isErrorMessage returns true if this ICMP type carries original packet info +// isErrorMessage returns true if this ICMP type carries original packet info. +// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match +// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's +// kept as a literal. 
func (info ICMPInfo) isErrorMessage() bool { typ := info.TypeCode.Type() - return typ == 3 || // Destination Unreachable - typ == 5 || // Redirect - typ == 11 || // Time Exceeded - typ == 12 // Parameter Problem + // ICMPv4 error types + if typ == layers.ICMPv4TypeDestinationUnreachable || + typ == layers.ICMPv4TypeRedirect || + typ == layers.ICMPv4TypeTimeExceeded || + typ == layers.ICMPv4TypeParameterProblem { + return true + } + // ICMPv6 error types (type 3 already matched above as v4 DestUnreachable) + if typ == layers.ICMPv6TypeDestinationUnreachable || + typ == layers.ICMPv6TypePacketTooBig || + typ == layers.ICMPv6TypeParameterProblem { + return true + } + return false } // parseOriginalPacket extracts info about the original packet from ICMP payload func (info ICMPInfo) parseOriginalPacket() string { - if info.PayloadLen < MaxICMPPayloadLength { + if info.PayloadLen == 0 { return "" } - // TODO: handle IPv6 - if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 { + version := (info.PayloadData[0] >> 4) & 0xF + + var protocol uint8 + var srcIP, dstIP net.IP + var transportData []byte + + switch version { + case 4: + if info.PayloadLen < minICMPPayloadIPv4 { + return "" + } + protocol = info.PayloadData[9] + srcIP = net.IP(info.PayloadData[12:16]) + dstIP = net.IP(info.PayloadData[16:20]) + transportData = info.PayloadData[20:] + case 6: + if info.PayloadLen < minICMPPayloadIPv6 { + return "" + } + // Next Header field in IPv6 header + protocol = info.PayloadData[6] + srcIP = net.IP(info.PayloadData[8:24]) + dstIP = net.IP(info.PayloadData[24:40]) + transportData = info.PayloadData[40:] + default: return "" } - protocol := info.PayloadData[9] - srcIP := net.IP(info.PayloadData[12:16]) - dstIP := net.IP(info.PayloadData[16:20]) - - transportData := info.PayloadData[20:] - switch nftypes.Protocol(protocol) { case nftypes.TCP: srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) dstPort := uint16(transportData[2])<<8 | 
uint16(transportData[3]) - return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) + return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))) case nftypes.UDP: srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) - return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) + return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))) case nftypes.ICMP: icmpType := transportData[0] @@ -259,9 +295,10 @@ func (t *ICMPTracker) track( t.sendEvent(nftypes.TypeStart, conn, ruleId) } -// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request. +// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies. 
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { - if icmpType != uint8(layers.ICMPv4TypeEchoReply) { + if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) { return false } @@ -343,6 +380,13 @@ func (t *ICMPTracker) cleanup() { } } +func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol { + if ip.Is6() { + return nftypes.ICMPv6 + } + return nftypes.ICMP +} + // Close stops the cleanup routine and releases resources func (t *ICMPTracker) Close() { t.tickerCancel() @@ -358,7 +402,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID [] Type: typ, RuleID: ruleID, Direction: conn.Direction, - Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 + Protocol: icmpProtocolForAddr(conn.SourceIP), SourceIP: conn.SourceIP, DestIP: conn.DestIP, ICMPType: conn.ICMPType, @@ -376,7 +420,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad Type: nftypes.TypeStart, RuleID: ruleID, Direction: direction, - Protocol: nftypes.ICMP, + Protocol: icmpProtocolForAddr(srcIP), SourceIP: srcIP, DestIP: dstIP, ICMPType: typ, diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index b15b42cf0..6d1f87162 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -5,6 +5,42 @@ import ( "testing" ) +func TestICMPConnKey_String(t *testing.T) { + tests := []struct { + name string + key ICMPConnKey + expect string + }{ + { + name: "IPv4", + key: ICMPConnKey{ + SrcIP: netip.MustParseAddr("192.168.1.1"), + DstIP: netip.MustParseAddr("10.0.0.1"), + ID: 1234, + }, + expect: "192.168.1.1 → 10.0.0.1 (id 1234)", + }, + { + name: "IPv6", + key: ICMPConnKey{ + SrcIP: netip.MustParseAddr("2001:db8::1"), + DstIP: netip.MustParseAddr("2001:db8::2"), + ID: 5678, + }, + expect: "2001:db8::1 → 2001:db8::2 (id 5678)", 
+ }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.key.String() + if got != tc.expect { + t.Errorf("got %q, want %q", got, tc.expect) + } + }) + } +} + func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 1d4dcb1e5..91866dcab 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -18,9 +18,10 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/uuid" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" @@ -35,8 +36,10 @@ import ( const ( layerTypeAll = 255 - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation + ipv4TCPHeaderMinSize = 40 + // ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation + ipv6TCPHeaderMinSize = 60 ) // serviceKey represents a protocol/port combination for netstack service registry @@ -115,14 +118,15 @@ type Manager struct { localipmanager *localIPManager - udpTracker *conntrack.UDPTracker - icmpTracker *conntrack.ICMPTracker - tcpTracker *conntrack.TCPTracker - forwarder atomic.Pointer[forwarder.Forwarder] - logger *nblog.Logger - flowLogger nftypes.FlowLogger + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker + forwarder atomic.Pointer[forwarder.Forwarder] + pendingCapture 
atomic.Pointer[forwarder.PacketCapture] + logger *nblog.Logger + flowLogger nftypes.FlowLogger - blockRule firewall.Rule + blockRules []firewall.Rule // Internal 1:1 DNAT dnatEnabled atomic.Bool @@ -137,9 +141,10 @@ type Manager struct { netstackServices map[serviceKey]struct{} netstackServiceMutex sync.RWMutex - mtu uint16 - mssClampValue uint16 - mssClampEnabled bool + mtu uint16 + mssClampValueIPv4 uint16 + mssClampValueIPv6 uint16 + mssClampEnabled bool // Only one hook per protocol is supported. Outbound direction only. udpHookOut atomic.Pointer[common.PacketHook] @@ -156,11 +161,28 @@ type decoder struct { icmp4 layers.ICMPv4 icmp6 layers.ICMPv6 decoded []gopacket.LayerType - parser *gopacket.DecodingLayerParser + parser4 *gopacket.DecodingLayerParser + parser6 *gopacket.DecodingLayerParser dnatOrigPort uint16 } +// decodePacket decodes packet data using the appropriate parser based on IP version. +func (d *decoder) decodePacket(data []byte) error { + if len(data) == 0 { + return errors.New("empty packet") + } + version := data[0] >> 4 + switch version { + case 4: + return d.parser4.DecodeLayers(data, &d.decoded) + case 6: + return d.parser6.DecodeLayers(data, &d.decoded) + default: + return fmt.Errorf("unknown IP version %d", version) + } +} + // Create userspace firewall manager constructor func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { return create(iface, nil, disableServerRoutes, flowLogger, mtu) @@ -218,11 +240,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, 
&d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, }, @@ -248,7 +276,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe if !disableMSSClamping { m.mssClampEnabled = true - m.mssClampValue = mtu - ipTCPHeaderMinSize + if mtu > ipv4TCPHeaderMinSize { + m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize + } + if mtu > ipv6TCPHeaderMinSize { + m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize + } } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) @@ -271,13 +304,25 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe return m, nil } -func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { +// blockInvalidRouted installs drop rules for traffic to the wg overlay that +// arrives via the routing path. v4 and v6 are independent: a v6 install +// failure leaves v4 protection in place (and vice versa) so the returned +// slice always contains whatever was successfully installed, even on error. +// Callers must persist the slice so DisableRouting can clean partial state. 
+func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) { wgPrefix := iface.Address().Network log.Debugf("blocking invalid routed traffic for %s", wgPrefix) - rule, err := m.addRouteFiltering( + sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)} + v6Net := iface.Address().IPv6Net + if v6Net.IsValid() { + sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0)) + } + + var rules []firewall.Rule + v4Rule, err := m.addRouteFiltering( nil, - []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, + sources, firewall.Network{Prefix: wgPrefix}, firewall.ProtocolALL, nil, @@ -285,12 +330,30 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e firewall.ActionDrop, ) if err != nil { - return nil, fmt.Errorf("block wg nte : %w", err) + return rules, fmt.Errorf("block wg v4 net: %w", err) + } + rules = append(rules, v4Rule) + + if v6Net.IsValid() { + log.Debugf("blocking invalid routed traffic for %s", v6Net) + v6Rule, err := m.addRouteFiltering( + nil, + sources, + firewall.Network{Prefix: v6Net}, + firewall.ProtocolALL, + nil, + nil, + firewall.ActionDrop, + ) + if err != nil { + return rules, fmt.Errorf("block wg v6 net: %w", err) + } + rules = append(rules, v6Rule) } // TODO: Block networks that we're a client of - return rule, nil + return rules, nil } func (m *Manager) determineRouting() error { @@ -351,6 +414,19 @@ func (m *Manager) determineRouting() error { return nil } +// SetPacketCapture sets or clears packet capture on the forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. 
+func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) { + if pc == nil { + m.pendingCapture.Store(nil) + } else { + m.pendingCapture.Store(&pc) + } + if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(pc) + } +} + // initForwarder initializes the forwarder, it disables routing on errors func (m *Manager) initForwarder() error { if m.forwarder.Load() != nil { @@ -372,6 +448,11 @@ func (m *Manager) initForwarder() error { m.forwarder.Store(forwarder) + // Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture. + if pc := m.pendingCapture.Load(); pc != nil { + forwarder.SetCapture(*pc) + } + log.Debug("forwarder initialized") return nil @@ -502,7 +583,7 @@ func (m *Manager) addRouteFiltering( mgmtId: id, sources: sources, dstSet: destination.Set, - protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), + protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)), srcPort: sPort, dstPort: dPort, action: action, @@ -593,10 +674,10 @@ func (m *Manager) Flush() error { return nil } // resetState clears all firewall rules and closes connection trackers. // Must be called with m.mutex held. func (m *Manager) resetState() { - maps.Clear(m.outgoingRules) - maps.Clear(m.incomingDenyRules) - maps.Clear(m.incomingRules) - maps.Clear(m.routeRulesMap) + clear(m.outgoingRules) + clear(m.incomingDenyRules) + clear(m.incomingRules) + clear(m.routeRulesMap) m.routeRules = m.routeRules[:0] m.udpHookOut.Store(nil) m.tcpHookOut.Store(nil) @@ -614,6 +695,7 @@ func (m *Manager) resetState() { } if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(nil) fwder.Stop() } @@ -656,11 +738,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { } destinations := matches[0].destinations - for _, prefix := range prefixes { - if prefix.Addr().Is4() { - destinations = append(destinations, prefix) - } - } + destinations = append(destinations, prefixes...) 
slices.SortFunc(destinations, func(a, b netip.Prefix) int { cmp := a.Addr().Compare(b.Addr()) @@ -699,7 +777,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { return false } @@ -785,12 +863,32 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { return false } + var mssClampValue uint16 + var ipHeaderSize int + switch d.decoded[0] { + case layers.LayerTypeIPv4: + mssClampValue = m.mssClampValueIPv4 + ipHeaderSize = int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + case layers.LayerTypeIPv6: + mssClampValue = m.mssClampValueIPv6 + ipHeaderSize = 40 + default: + return false + } + + if mssClampValue == 0 { + return false + } + mssOptionIndex := -1 var currentMSS uint16 for i, opt := range d.tcp.Options { if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { currentMSS = binary.BigEndian.Uint16(opt.OptionData) - if currentMSS > m.mssClampValue { + if currentMSS > mssClampValue { mssOptionIndex = i break } @@ -801,22 +899,17 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { return false } - ipHeaderSize := int(d.ip4.IHL) * 4 - if ipHeaderSize < 20 { - return false - } - - if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { + if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) { return false } if m.logger.Enabled(nblog.LevelTrace) { - m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) } return true } -func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool { 
tcpHeaderStart := ipHeaderSize tcpOptionsStart := tcpHeaderStart + 20 @@ -831,7 +924,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, } mssValueOffset := optOffset + 2 - binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue) m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) return true @@ -841,18 +934,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade tcpLayer := packetData[tcpHeaderStart:] tcpLength := len(packetData) - tcpHeaderStart + // Zero out existing checksum tcpLayer[16] = 0 tcpLayer[17] = 0 + // Build pseudo-header checksum based on IP version var pseudoSum uint32 - pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) - pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) - pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) - pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) - pseudoSum += uint32(d.ip4.Protocol) - pseudoSum += uint32(tcpLength) + switch d.decoded[0] { + case layers.LayerTypeIPv4: + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += uint32(tcpLength) + case layers.LayerTypeIPv6: + for i := 0; i < 16; i += 2 { + pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1]) + } + for i := 0; i < 16; i += 2 { + pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1]) + } + pseudoSum += uint32(tcpLength) + pseudoSum += uint32(layers.IPProtocolTCP) + } - var sum = pseudoSum + sum := pseudoSum for i := 0; i < tcpLength-1; i += 2 { sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) } @@ -890,6 +997,9 @@ func (m *Manager) 
trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) + case layers.LayerTypeICMPv6: + id, tc := icmpv6EchoFields(d) + m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size) } } @@ -903,6 +1013,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) + case layers.LayerTypeICMPv6: + id, tc := icmpv6EchoFields(d) + m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size) } d.dnatOrigPort = 0 @@ -936,8 +1049,12 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // TODO: pass fragments of routed packets to forwarder if fragment { if m.logger.Enabled(nblog.LevelTrace) { - m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", - srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + if d.decoded[0] == layers.LayerTypeIPv4 { + m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", + srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + } else { + m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP) + } } return false } @@ -945,7 +1062,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // TODO: optimize port DNAT by caching matched rules in conntrack if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { // Re-decode after port DNAT translation to update port information - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) return true } @@ -954,7 +1071,7 @@ func (m *Manager) 
filterInbound(packetData []byte, size int) bool { if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } @@ -1092,6 +1209,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe return true } +// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps +// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle +// both families uniformly. The echo ID is in the first two payload bytes. +func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) { + if len(d.icmp6.Payload) >= 2 { + id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1]) + } + // Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking. + switch d.icmp6.TypeCode.Type() { + case layers.ICMPv6TypeEchoRequest: + tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0) + case layers.ICMPv6TypeEchoReply: + tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0) + default: + tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code()) + } + return id, tc +} + +// protoLayerMatches checks if a packet's protocol layer matches a rule's expected +// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching +// ICMP rules since management sends a single ICMP rule for both families. 
+func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool { + if ruleLayer == packetLayer { + return true + } + if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 { + return true + } + if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 { + return true + } + return false +} + +func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType { + if p.Addr().Is6() { + return layers.LayerTypeIPv6 + } + return layers.LayerTypeIPv4 +} + func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { switch proto { case firewall.ProtocolTCP: @@ -1115,8 +1274,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol { return nftypes.TCP case layers.LayerTypeUDP: return nftypes.UDP - case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: + case layers.LayerTypeICMPv4: return nftypes.ICMP + case layers.LayerTypeICMPv6: + return nftypes.ICMPv6 default: return nftypes.ProtocolUnknown } @@ -1137,7 +1298,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { // It returns true, false if the packet is valid and not a fragment. // It returns true, true if the packet is a fragment and valid. 
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { if m.logger.Enabled(nblog.LevelTrace) { m.logger.Trace1("couldn't decode packet, err: %s", err) } @@ -1152,10 +1313,21 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { } // Fragments are also valid - if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 { - ip4 := d.ip4 - if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 { - return true, true + if l == 1 { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 { + return true, true + } + case layers.LayerTypeIPv6: + // IPv6 uses Fragment extension header (NextHeader=44). If gopacket + // only decoded the IPv6 layer, the transport is in a fragment. + // TODO: handle non-Fragment extension headers (HopByHop, Routing, + // DestOpts) by walking the chain. gopacket's parser does not + // support them as DecodingLayers; today we drop such packets. + if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment { + return true, true + } } } @@ -1193,21 +1365,35 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size, ) - // TODO: ICMPv6 + case layers.LayerTypeICMPv6: + id, _ := icmpv6EchoFields(d) + return m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + id, + d.icmp6.TypeCode.Type(), + size, + ) } return false } -// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed +// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed. 
func (m *Manager) isSpecialICMP(d *decoder) bool { - if d.decoded[1] != layers.LayerTypeICMPv4 { - return false + switch d.decoded[1] { + case layers.LayerTypeICMPv4: + icmpType := d.icmp4.TypeCode.Type() + return icmpType == layers.ICMPv4TypeDestinationUnreachable || + icmpType == layers.ICMPv4TypeTimeExceeded + case layers.LayerTypeICMPv6: + icmpType := d.icmp6.TypeCode.Type() + return icmpType == layers.ICMPv6TypeDestinationUnreachable || + icmpType == layers.ICMPv6TypePacketTooBig || + icmpType == layers.ICMPv6TypeTimeExceeded || + icmpType == layers.ICMPv6TypeParameterProblem } - - icmpType := d.icmp4.TypeCode.Type() - return icmpType == layers.ICMPv4TypeDestinationUnreachable || - icmpType == layers.ICMPv4TypeTimeExceeded + return false } func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) { @@ -1264,7 +1450,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d return rule.mgmtId, rule.drop, true } - if payloadLayer != rule.protoLayer { + if !protoLayerMatches(rule.protoLayer, payloadLayer) { continue } @@ -1299,8 +1485,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay } func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { - // TODO: handle ipv6 vs ipv4 icmp rules - if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer { + if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) { return false } @@ -1361,13 +1546,14 @@ func (m *Manager) EnableRouting() error { return nil } - rule, err := m.blockInvalidRouted(m.wgIface) + rules, err := m.blockInvalidRouted(m.wgIface) + // Persist whatever was installed even on partial failure, so DisableRouting + // can clean it up later. 
+ m.blockRules = rules if err != nil { return fmt.Errorf("block invalid routed: %w", err) } - m.blockRule = rule - return nil } @@ -1383,9 +1569,16 @@ func (m *Manager) DisableRouting() error { m.routingEnabled.Store(false) m.nativeRouter.Store(false) - // don't stop forwarder if in use by netstack + var merr *multierror.Error + for _, rule := range m.blockRules { + if err := m.deleteRouteRule(rule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete block rule: %w", err)) + } + } + m.blockRules = nil + if m.netstack && m.localForwarding { - return nil + return nberrors.FormatErrorOrNil(merr) } fwder.Stop() @@ -1393,14 +1586,7 @@ func (m *Manager) DisableRouting() error { log.Debug("forwarder stopped") - if m.blockRule != nil { - if err := m.deleteRouteRule(m.blockRule); err != nil { - return fmt.Errorf("delete block rule: %w", err) - } - m.blockRule = nil - } - - return nil + return nberrors.FormatErrorOrNil(merr) } // RegisterNetstackService registers a service as listening on the netstack for the given protocol and port @@ -1454,7 +1640,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { } // traffic to our other local interfaces (not NetBird IP) - always forward - if dstIP != m.wgIface.Address().IP { + addr := m.wgIface.Address() + if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) { return true } diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 10ff62ed3..4dccb0f65 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) { }() manager.mssClampEnabled = true - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 srcIP := net.ParseIP("100.64.0.2") dstIP := net.ParseIP("8.8.8.8") @@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) { manager.mssClampEnabled = 
sc.enabled if sc.enabled { - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 } srcIP := net.ParseIP("100.64.0.2") @@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) { }() manager.mssClampEnabled = true - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 srcIP := net.ParseIP("100.64.0.2") dstIP := net.ParseIP("8.8.8.8") diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index a8efbac1c..a64c83138 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) { } } +func TestPeerACLFilteringIPv6(t *testing.T) { + localIP := netip.MustParseAddr("100.10.0.100") + localIPv6 := netip.MustParseAddr("fd00::100") + wgNet := netip.MustParsePrefix("100.10.0.0/16") + wgNetV6 := netip.MustParsePrefix("fd00::/64") + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: localIP, + Network: wgNet, + IPv6: localIPv6, + IPv6Net: wgNetV6, + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + testCases := []struct { + name string + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + ruleIP string + ruleProto fw.Protocol + ruleDstPort *fw.Port + ruleAction fw.Action + shouldBeBlocked bool + }{ + { + name: "IPv6: allow TCP from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: 
false, + }, + { + name: "IPv6: allow UDP from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{53}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: allow ICMPv6 from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolICMP, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: block TCP without rule", + srcIP: "fd00::2", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "IPv6: drop rule", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 22, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{22}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "IPv6: allow all protocols", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 9999, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolICMP, + ruleIP: "0.0.0.0", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + } + + t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) { + packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443) + isDropped := manager.FilterInbound(packet, 0) + require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist") + }) + + for _, tc := range testCases { + t.Run(tc.name, func(t 
*testing.T) { + if tc.ruleAction == fw.ActionDrop { + rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "") + require.NoError(t, err) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + } + + rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "") + require.NoError(t, err) + require.NotEmpty(t, rules) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + + packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + isDropped := manager.FilterInbound(packet, 0) + require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch") + }) + } +} + func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte { t.Helper() + src := net.ParseIP(srcIP) + dst := net.ParseIP(dstIP) + buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } - ipLayer := &layers.IPv4{ - Version: 4, - TTL: 64, - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - } + // Detect address family + isV6 := src.To4() == nil var err error - switch proto { - case fw.ProtocolTCP: - ipLayer.Protocol = layers.IPProtocolTCP - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - err = tcp.SetNetworkLayerForChecksum(ipLayer) - require.NoError(t, err) - err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp) - case fw.ProtocolUDP: - ipLayer.Protocol = layers.IPProtocolUDP - udp := &layers.UDP{ - SrcPort: layers.UDPPort(srcPort), - DstPort: layers.UDPPort(dstPort), + if isV6 { + ip6 := &layers.IPv6{ + Version: 6, + HopLimit: 64, + SrcIP: src, + DstIP: dst, } - err = udp.SetNetworkLayerForChecksum(ipLayer) - require.NoError(t, err) - err = gopacket.SerializeLayers(buf, opts, 
ipLayer, udp) - case fw.ProtocolICMP: - ipLayer.Protocol = layers.IPProtocolICMPv4 - icmp := &layers.ICMPv4{ - TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), + switch proto { + case fw.ProtocolTCP: + ip6.NextHeader = layers.IPProtocolTCP + tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)} + _ = tcp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, tcp) + case fw.ProtocolUDP: + ip6.NextHeader = layers.IPProtocolUDP + udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)} + _ = udp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, udp) + case fw.ProtocolICMP: + ip6.NextHeader = layers.IPProtocolICMPv6 + icmp := &layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0), + } + _ = icmp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, icmp) + default: + err = gopacket.SerializeLayers(buf, opts, ip6) + } + } else { + ip4 := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: src, + DstIP: dst, } - err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp) - default: - err = gopacket.SerializeLayers(buf, opts, ipLayer) + switch proto { + case fw.ProtocolTCP: + ip4.Protocol = layers.IPProtocolTCP + tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)} + _ = tcp.SetNetworkLayerForChecksum(ip4) + err = gopacket.SerializeLayers(buf, opts, ip4, tcp) + case fw.ProtocolUDP: + ip4.Protocol = layers.IPProtocolUDP + udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)} + _ = udp.SetNetworkLayerForChecksum(ip4) + err = gopacket.SerializeLayers(buf, opts, ip4, udp) + case fw.ProtocolICMP: + ip4.Protocol = layers.IPProtocolICMPv4 + icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)} + err = gopacket.SerializeLayers(buf, opts, ip4, icmp) + default: + err = 
gopacket.SerializeLayers(buf, opts, ip4) + } } require.NoError(t, err) @@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) { _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") } + +// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass. +// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only) +// but the ACL decision itself is correct. +func TestRouteACLFilteringIPv6(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48") + _, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00::/16")}, + fw.Network{Prefix: v6Dst}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{80}}, + fw.ActionAccept, + ) + require.NoError(t, err) + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00::/16")}, + fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")}, + fw.ProtocolALL, + nil, + nil, + fw.ActionDrop, + ) + require.NoError(t, err) + + tests := []struct { + name string + srcIP netip.Addr + dstIP netip.Addr + proto gopacket.LayerType + srcPort uint16 + dstPort uint16 + allowed bool + }{ + { + name: "IPv6 TCP to allowed dest", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: true, + }, + { + name: "IPv6 TCP wrong port", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 443, + allowed: false, + }, + { + name: "IPv6 UDP not matched by TCP rule", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeUDP, + srcPort: 12345, + 
dstPort: 80, + allowed: false, + }, + { + name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeICMPv6, + allowed: false, + }, + { + name: "IPv6 to denied subnet", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: false, + }, + { + name: "IPv6 source outside allowed range", + srcIP: netip.MustParseAddr("fe80::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + require.Equal(t, tc.allowed, pass, "route ACL result mismatch") + }) + } +} diff --git a/client/firewall/uspfilter/filter_routeacl_test.go b/client/firewall/uspfilter/filter_routeacl_test.go index 68572a01c..449554d8b 100644 --- a/client/firewall/uspfilter/filter_routeacl_test.go +++ b/client/firewall/uspfilter/filter_routeacl_test.go @@ -189,21 +189,21 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) { }) // Call blockInvalidRouted directly multiple times - rule1, err := manager.blockInvalidRouted(ifaceMock) + rules1, err := manager.blockInvalidRouted(ifaceMock) require.NoError(t, err) - require.NotNil(t, rule1) + require.NotEmpty(t, rules1) - rule2, err := manager.blockInvalidRouted(ifaceMock) + rules2, err := manager.blockInvalidRouted(ifaceMock) require.NoError(t, err) - require.NotNil(t, rule2) + require.NotEmpty(t, rules2) - rule3, err := manager.blockInvalidRouted(ifaceMock) + rules3, err := manager.blockInvalidRouted(ifaceMock) require.NoError(t, err) - require.NotNil(t, rule3) + require.NotEmpty(t, rules3) - // All should return the same rule - assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should 
return same rule") - assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule") + // All calls should return the same v4 block rule (idempotent install). + assert.Equal(t, rules1[0].ID(), rules2[0].ID(), "Second call should return same v4 rule") + assert.Equal(t, rules2[0].ID(), rules3[0].ID(), "Third call should return same v4 rule") // Should have exactly 1 route rule manager.mutex.RLock() diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 39e8efa2c..f19c4bb56 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger()) var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() type IFaceMock struct { + NameFunc func() string SetFilterFunc func(device.PacketFilter) error AddressFunc func() wgaddr.Address GetWGDeviceFunc func() *wgdevice.Device GetDeviceFunc func() *device.FilteredDevice } +func (i *IFaceMock) Name() string { + if i.NameFunc == nil { + return "wgtest" + } + return i.NameFunc() +} + func (i *IFaceMock) GetWGDevice() *wgdevice.Device { if i.GetWGDeviceFunc == nil { return nil @@ -527,11 +535,16 @@ func TestProcessOutgoingHooks(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, } @@ -630,11 +643,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, 
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, } @@ -1040,8 +1058,8 @@ func TestMSSClamping(t *testing.T) { }() require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") - expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) - require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40") + require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60") err = manager.UpdateLocalIPs() require.NoError(t, err) @@ -1059,7 +1077,7 @@ func TestMSSClamping(t *testing.T) { require.Len(t, d.tcp.Options, 1, "Should have MSS option") require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) - require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40") }) t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { @@ -1083,7 +1101,7 @@ func TestMSSClamping(t *testing.T) { d := parsePacket(t, packet) require.Len(t, d.tcp.Options, 1, "Should have MSS option") actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) - require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped") }) t.Run("Non-SYN packet unchanged", func(t *testing.T) { @@ -1255,13 +1273,18 @@ func TestShouldForward(t *testing.T) { d := &decoder{ decoded: 
[]gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true - err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) + err = d.decodePacket(buf.Bytes()) require.NoError(t, err) return d @@ -1321,6 +1344,44 @@ func TestShouldForward(t *testing.T) { }, } + // Add IPv6 to the interface and test dual-stack cases + wgIPv6 := netip.MustParseAddr("fd00::1") + otherIPv6 := netip.MustParseAddr("fd00::2") + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{ + IP: wgIP, + Network: netip.PrefixFrom(wgIP, 24), + IPv6: wgIPv6, + IPv6Net: netip.PrefixFrom(wgIPv6, 64), + } + } + + // Re-create manager to pick up the new address with IPv6 + require.NoError(t, manager.Close(nil)) + manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + + v6Cases := []struct { + name string + dstIP netip.Addr + expected bool + description string + }{ + {"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"}, + {"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"}, + {"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"}, + {"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"}, + } + for _, tt := range v6Cases { + t.Run(tt.name, func(t *testing.T) { + manager.localForwarding = true + manager.netstack = false + decoder := createTCPDecoder(8080) + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) 
+ } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Configure manager diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 692a24140..fab776f2a 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -1,7 +1,8 @@ package forwarder import ( - "fmt" + "net" + "strconv" "sync/atomic" wgdevice "golang.zx2c4.com/wireguard/device" @@ -12,12 +13,19 @@ import ( nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + Offer(data []byte, outbound bool) +} + // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device type endpoint struct { logger *nblog.Logger dispatcher stack.NetworkDispatcher device *wgdevice.Device mtu atomic.Uint32 + capture atomic.Pointer[PacketCapture] } func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { @@ -47,20 +55,31 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { var written int for _, pkt := range pkts.AsSlice() { - netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice()) - data := stack.PayloadSince(pkt.NetworkHeader()) if data == nil { continue } - // Send the packet through WireGuard - address := netHeader.DestinationAddress() - err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) - if err != nil { + raw := pkt.NetworkHeader().View().AsSlice() + if len(raw) == 0 { + continue + } + var address tcpip.Address + if raw[0]>>4 == 6 { + address = header.IPv6(raw).DestinationAddress() + } else { + address = header.IPv4(raw).DestinationAddress() + } + + pktBytes := data.AsSlice() + if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil { 
e.logger.Error1("CreateOutboundPacket: %v", err) continue } + + if pc := e.capture.Load(); pc != nil { + (*pc).Offer(pktBytes, true) + } written++ } @@ -103,5 +122,7 @@ type epID stack.TransportEndpointID func (i epID) String() string { // src and remote is swapped - return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) + return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) + + " → " + + net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort))) } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index d17c3cd5c..6291eb285 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -14,6 +14,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -36,25 +37,31 @@ type Forwarder struct { logger *nblog.Logger flowLogger nftypes.FlowLogger // ruleIdMap is used to store the rule ID for a given connection - ruleIdMap sync.Map - stack *stack.Stack - endpoint *endpoint - udpForwarder *udpForwarder - ctx context.Context - cancel context.CancelFunc - ip tcpip.Address - netstack bool - hasRawICMPAccess bool - pingSemaphore chan struct{} + ruleIdMap sync.Map + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip tcpip.Address + ipv6 tcpip.Address + netstack bool + hasRawICMPAccess bool + hasRawICMPv6Access bool + pingSemaphore chan struct{} } func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ - NetworkProtocols: 
[]stack.NetworkProtocolFactory{ipv4.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, TransportProtocols: []stack.TransportProtocolFactory{ tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, + icmp.NewProtocol6, }, HandleLocal: false, }) @@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + Address: tcpip.AddrFrom4(iface.Address().IP.As4()), PrefixLen: iface.Address().Network.Bits(), }, } @@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return nil, fmt.Errorf("failed to add protocol address: %s", err) } + if v6 := iface.Address().IPv6; v6.IsValid() { + v6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom16(v6.As16()), + PrefixLen: iface.Address().IPv6Net.Bits(), + }, + } + if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("add IPv6 protocol address: %s", err) + } + } + defaultSubnet, err := tcpip.NewSubnet( tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), @@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return nil, fmt.Errorf("creating default subnet: %w", err) } + defaultSubnetV6, err := tcpip.NewSubnet( + tcpip.AddrFrom16([16]byte{}), + tcpip.MaskFromBytes(make([]byte, 16)), + ) + if err != nil { + return nil, fmt.Errorf("creating default v6 subnet: %w", err) + } + if err := s.SetPromiscuousMode(nicID, true); err != nil { return nil, fmt.Errorf("set promiscuous mode: %s", err) } @@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } s.SetRouteTable([]tcpip.Route{ 
- { - Destination: defaultSubnet, - NIC: nicID, - }, + {Destination: defaultSubnet, NIC: nicID}, + {Destination: defaultSubnetV6, NIC: nicID}, }) ctx, cancel := context.WithCancel(context.Background()) @@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow ctx: ctx, cancel: cancel, netstack: netstack, - ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + ip: tcpip.AddrFrom4(iface.Address().IP.As4()), + ipv6: addrFromNetipAddr(iface.Address().IPv6), pingSemaphore: make(chan struct{}, 3), } @@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow udpForwarder := udp.NewForwarder(s, f.handleUDP) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + // ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's + // network layer. This avoids duplicate echo replies (v4) and the v6 + // auto-reply bug where gVisor responds at the network layer before + // our transport handler fires. f.checkICMPCapability() @@ -139,9 +169,41 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return f, nil } +// SetCapture sets or clears the packet capture on the forwarder endpoint. +// This captures outbound packets that bypass the FilteredDevice (netstack forwarding). 
+func (f *Forwarder) SetCapture(pc PacketCapture) { + if pc == nil { + f.endpoint.capture.Store(nil) + return + } + f.endpoint.capture.Store(&pc) +} + func (f *Forwarder) InjectIncomingPacket(payload []byte) error { - if len(payload) < header.IPv4MinimumSize { - return fmt.Errorf("packet too small: %d bytes", len(payload)) + if len(payload) == 0 { + return fmt.Errorf("empty packet") + } + + var protoNum tcpip.NetworkProtocolNumber + switch payload[0] >> 4 { + case 4: + if len(payload) < header.IPv4MinimumSize { + return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload)) + } + if f.handleICMPDirect(payload) { + return nil + } + protoNum = ipv4.ProtocolNumber + case 6: + if len(payload) < header.IPv6MinimumSize { + return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload)) + } + if f.handleICMPDirect(payload) { + return nil + } + protoNum = ipv6.ProtocolNumber + default: + return fmt.Errorf("unknown IP version: %d", payload[0]>>4) } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -150,11 +212,160 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { defer pkt.DecRef() if f.endpoint.dispatcher != nil { - f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) + f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt) } return nil } +// handleICMPDirect intercepts ICMP packets from raw IP payloads before they +// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that +// the existing handlers expect, then dispatches to handleICMP/handleICMPv6. +// This bypasses gVisor's network layer which causes duplicate v4 echo replies +// and auto-replies to all v6 echo requests in promiscuous mode. +// +// Unlike gVisor's network layer, this does not validate ICMP checksums or +// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. 
+func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) { + if len(payload) < header.IPv4MinimumSize { + return 0, 0, src, dst, false + } + ip := header.IPv4(payload) + if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { + return 0, 0, src, dst, false + } + if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 { + return 0, 0, src, dst, false + } + ipHdrLen = int(ip.HeaderLength()) + totalLen := int(ip.TotalLength()) + if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) { + return 0, 0, src, dst, false + } + icmpLen = totalLen - ipHdrLen + if icmpLen < header.ICMPv4MinimumSize { + return 0, 0, src, dst, false + } + return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) { + if len(payload) < header.IPv6MinimumSize { + return 0, 0, src, dst, false + } + ip := header.IPv6(payload) + declaredLen := int(ip.PayloadLength()) + hdrEnd := header.IPv6MinimumSize + declaredLen + if hdrEnd > len(payload) { + return 0, 0, src, dst, false + } + icmpStart, ok := skipIPv6ExtensionsToICMPv6(payload, ip.NextHeader(), hdrEnd) + if !ok { + return 0, 0, src, dst, false + } + icmpLen = hdrEnd - icmpStart + if icmpLen < header.ICMPv6MinimumSize { + return 0, 0, src, dst, false + } + return icmpStart, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +// skipIPv6ExtensionsToICMPv6 walks the IPv6 extension-header chain starting +// after the fixed header. It advances past Hop-by-Hop, Routing, and +// Destination Options headers (which share the NextHeader+ExtLen+6+ExtLen*8 +// layout) and returns the offset of the ICMPv6 payload. Fragment, ESP, AH, +// and unknown identifiers are reported as not handleable so the caller can +// defer to gVisor. 
+func skipIPv6ExtensionsToICMPv6(payload []byte, next uint8, hdrEnd int) (int, bool) { + off := header.IPv6MinimumSize + for { + if next == uint8(header.ICMPv6ProtocolNumber) { + return off, true + } + if !isWalkableIPv6ExtHdr(next) { + return 0, false + } + newOff, newNext, ok := advanceIPv6ExtHdr(payload, off, hdrEnd) + if !ok { + return 0, false + } + off = newOff + next = newNext + } +} + +func isWalkableIPv6ExtHdr(id uint8) bool { + switch id { + case uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), + uint8(header.IPv6RoutingExtHdrIdentifier), + uint8(header.IPv6DestinationOptionsExtHdrIdentifier): + return true + } + return false +} + +func advanceIPv6ExtHdr(payload []byte, off, hdrEnd int) (int, uint8, bool) { + if off+8 > hdrEnd { + return 0, 0, false + } + extLen := (int(payload[off+1]) + 1) * 8 + if off+extLen > hdrEnd { + return 0, 0, false + } + return off + extLen, payload[off], true +} + +func (f *Forwarder) handleICMPDirect(payload []byte) bool { + if len(payload) == 0 { + return false + } + var ( + ipHdrLen int + icmpLen int + srcAddr tcpip.Address + dstAddr tcpip.Address + ok bool + ) + version := payload[0] >> 4 + switch version { + case 4: + ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload) + case 6: + ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload) + } + if !ok { + return false + } + + // Let gVisor handle ICMP destined for our own addresses natively. + // Its network-layer auto-reply is correct and efficient for local traffic. + if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) { + return false + } + + id := stack.TransportEndpointID{ + LocalAddress: dstAddr, + RemoteAddress: srcAddr, + } + + // Trim the buffer to the IP-declared length so gVisor doesn't see padding. 
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]), + }) + defer pkt.DecRef() + + if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok { + return false + } + if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok { + return false + } + + if version == 6 { + return f.handleICMPv6(id, pkt) + } + return f.handleICMP(id, pkt) +} + // Stop gracefully shuts down the forwarder func (f *Forwarder) Stop() { f.cancel() @@ -167,11 +378,14 @@ func (f *Forwarder) Stop() { f.stack.Wait() } -func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { +func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr { if f.netstack && f.ip.Equal(addr) { - return net.IPv4(127, 0, 0, 1) + return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } - return addr.AsSlice() + if f.netstack && f.ipv6.Equal(addr) { + return netip.IPv6Loopback() + } + return addrToNetipAddr(addr) } func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { @@ -205,23 +419,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe } } +// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating. +func addrFromNetipAddr(addr netip.Addr) tcpip.Address { + if !addr.IsValid() { + return tcpip.Address{} + } + if addr.Is4() { + return tcpip.AddrFrom4(addr.As4()) + } + return tcpip.AddrFrom16(addr.As16()) +} + +// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating. +func addrToNetipAddr(addr tcpip.Address) netip.Addr { + switch addr.Len() { + case 4: + return netip.AddrFrom4(addr.As4()) + case 16: + return netip.AddrFrom16(addr.As16()) + default: + return netip.Addr{} + } +} + // checkICMPCapability tests whether we have raw ICMP socket access at startup. 
func (f *Forwarder) checkICMPCapability() { + f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger) + f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger) +} + +func probeRawICMP(network, addr string, logger *nblog.Logger) bool { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, addr) if err != nil { - f.hasRawICMPAccess = false - f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") - return + logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network) + return false } if err := conn.Close(); err != nil { - f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) + logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err) } - f.hasRawICMPAccess = true - f.logger.Debug("forwarder: Raw ICMP socket access available") + logger.Debug1("forwarder: raw %s socket access available", network) + return true } diff --git a/client/firewall/uspfilter/forwarder/forwarder_test.go b/client/firewall/uspfilter/forwarder/forwarder_test.go new file mode 100644 index 000000000..ad74e8493 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/forwarder_test.go @@ -0,0 +1,162 @@ +package forwarder + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const echoRequestSize = 8 + +func makeIPv6(t *testing.T, src, dst netip.Addr, nextHdr uint8, payload []byte) []byte { + t.Helper() + buf := make([]byte, header.IPv6MinimumSize+len(payload)) + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload)), + TransportProtocol: 0, // overwritten below to allow any value + HopLimit: 64, + SrcAddr: 
tcpipAddrFromNetip(src), + DstAddr: tcpipAddrFromNetip(dst), + }) + buf[6] = nextHdr + copy(buf[header.IPv6MinimumSize:], payload) + return buf +} + +func tcpipAddrFromNetip(a netip.Addr) tcpip.Address { + b := a.As16() + return tcpip.AddrFrom16(b) +} + +func echoRequest() []byte { + icmp := make([]byte, echoRequestSize) + icmp[0] = uint8(header.ICMPv6EchoRequest) + return icmp +} + +// extHdr builds a generic IPv6 extension header (HBH/Routing/DestOpts) of the +// given total octet length (must be multiple of 8, >= 8) with the given next +// header. +func extHdr(t *testing.T, next uint8, totalLen int) []byte { + t.Helper() + require.GreaterOrEqual(t, totalLen, 8) + require.Equal(t, 0, totalLen%8) + buf := make([]byte, totalLen) + buf[0] = next + buf[1] = uint8(totalLen/8 - 1) + return buf +} + +func TestParseICMPv6_NoExtensions(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), echoRequest()) + + off, icmpLen, _, _, ok := parseICMPv6(pkt) + require.True(t, ok) + assert.Equal(t, header.IPv6MinimumSize, off) + assert.Equal(t, echoRequestSize, icmpLen) +} + +func TestParseICMPv6_SingleExtension(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + hbh := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 8) + payload := append([]byte{}, hbh...) + payload = append(payload, echoRequest()...) 
+ pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload) + + off, icmpLen, _, _, ok := parseICMPv6(pkt) + require.True(t, ok) + assert.Equal(t, header.IPv6MinimumSize+8, off) + assert.Equal(t, echoRequestSize, icmpLen) +} + +func TestParseICMPv6_ChainedExtensions(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + dest := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 16) + rt := extHdr(t, uint8(header.IPv6DestinationOptionsExtHdrIdentifier), 8) + hbh := extHdr(t, uint8(header.IPv6RoutingExtHdrIdentifier), 8) + payload := append(append(append([]byte{}, hbh...), rt...), dest...) + payload = append(payload, echoRequest()...) + pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload) + + off, icmpLen, _, _, ok := parseICMPv6(pkt) + require.True(t, ok) + assert.Equal(t, header.IPv6MinimumSize+8+8+16, off) + assert.Equal(t, echoRequestSize, icmpLen) +} + +func TestParseICMPv6_FragmentDefersToGVisor(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + pkt := makeIPv6(t, src, dst, uint8(header.IPv6FragmentExtHdrIdentifier), make([]byte, 8)) + + _, _, _, _, ok := parseICMPv6(pkt) + assert.False(t, ok) +} + +func TestParseICMPv6_TruncatedExtension(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + // Extension claims 16 bytes but only 8 remain after the IP header. + hbh := []byte{uint8(header.ICMPv6ProtocolNumber), 1, 0, 0, 0, 0, 0, 0} + pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), hbh) + + _, _, _, _, ok := parseICMPv6(pkt) + assert.False(t, ok) +} + +func TestParseICMPv6_TruncatedICMPPayload(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + // PayloadLength claims 8 bytes of ICMPv6 but the buffer only holds 4. 
+ pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), make([]byte, 8)) + pkt = pkt[:header.IPv6MinimumSize+4] + + _, _, _, _, ok := parseICMPv6(pkt) + assert.False(t, ok) +} + +func TestParseICMPv4_RejectsShortIHL(t *testing.T) { + pkt := make([]byte, 28) + pkt[0] = 0x44 // version 4, IHL 4 (16 bytes - below minimum) + pkt[9] = uint8(header.ICMPv4ProtocolNumber) + header.IPv4(pkt).SetTotalLength(28) + + _, _, _, _, ok := parseICMPv4(pkt) + assert.False(t, ok) +} + +func TestParseICMPv4_RejectsTotalLenOverBuffer(t *testing.T) { + pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize) + ip := header.IPv4(pkt) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(pkt) + 16), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 64, + }) + + _, _, _, _, ok := parseICMPv4(pkt) + assert.False(t, ok) +} + +func TestParseICMPv4_RejectsFragment(t *testing.T) { + pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize) + ip := header.IPv4(pkt) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(pkt)), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 64, + Flags: header.IPv4FlagMoreFragments, + }) + + _, _, _, _, ok := parseICMPv4(pkt) + assert.False(t, ok) +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 5aa280d43..d6d4e705e 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -36,7 +36,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu } icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() - conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond) if err != nil { f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) return true @@ -59,7 +59,7 @@ 
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI defer func() { <-f.pingSemaphore }() if f.hasRawICMPAccess { - f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false) } else { f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) } @@ -73,18 +73,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI // forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. // The caller is responsible for closing the returned connection. -func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { +func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) { ctx, cancel := context.WithTimeout(f.ctx, timeout) defer cancel() + network, listenAddr := "ip4:icmp", "0.0.0.0" + if v6 { + network, listenAddr = "ip6:ipv6-icmp", "::" + } + lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, listenAddr) if err != nil { return nil, fmt.Errorf("create ICMP socket: %w", err) } dstIP := f.determineDialAddr(id.LocalAddress) - dst := &net.IPAddr{IP: dstIP} + dst := &net.IPAddr{IP: dstIP.AsSlice()} if _, err = conn.WriteTo(payload, dst); err != nil { if closeErr := conn.Close(); closeErr != nil { @@ -101,11 +106,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by return conn, nil } -// handleICMPViaSocket handles ICMP echo requests using raw sockets. 
-func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { +// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6. +func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) { sendTime := time.Now() - conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) + conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second) if err != nil { f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) return @@ -116,18 +121,22 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp } }() - txBytes := f.handleEchoResponse(conn, id) + txBytes := f.handleEchoResponse(conn, id, v6) rtt := time.Since(sendTime).Round(10 * time.Microsecond) if f.logger.Enabled(nblog.LevelTrace) { - f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", - epID(id), icmpType, icmpCode, rtt) + proto := "ICMP" + if v6 { + proto = "ICMPv6" + } + f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)", + proto, epID(id), icmpType, icmpCode, rtt) } f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } -func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int { +func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) return 0 @@ -142,6 +151,19 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn return 0 } + if v6 { + // Recompute checksum: the raw socket response has a checksum computed + // 
over the real endpoint addresses, but we inject with overlay addresses. + icmpHdr := header.ICMPv6(response[:n]) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, + })) + return f.injectICMPv6Reply(id, response[:n]) + } + return f.injectICMPReply(id, response[:n]) } @@ -155,19 +177,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T txPackets = 1 } - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) + + proto := nftypes.ICMP + if srcIp.Is6() { + proto = nftypes.ICMPv6 + } fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, - Protocol: nftypes.ICMP, - // TODO: handle ipv6 - SourceIP: srcIp, - DestIP: dstIp, - ICMPType: icmpType, - ICMPCode: icmpCode, + Protocol: proto, + SourceIP: srcIp, + DestIP: dstIp, + ICMPType: icmpType, + ICMPCode: icmpCode, RxBytes: rxBytes, TxBytes: txBytes, @@ -218,26 +244,164 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } +// handleICMPv6 handles ICMPv6 packets from the network stack. 
+func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice()) + + flowID := uuid.New() + f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0) + + if icmpHdr.Type() == header.ICMPv6EchoRequest { + return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code())) + } + + // For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting + if !f.hasRawICMPv6Access { + f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id)) + return false + } + + icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond) + if err != nil { + f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err) + return true + } + if err := conn.Close(); err != nil { + f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err) + } + + return true +} + +// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback. 
+func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool { + select { + case f.pingSemaphore <- struct{}{}: + icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice() + rxBytes := pkt.Size() + + go func() { + defer func() { <-f.pingSemaphore }() + + if f.hasRawICMPv6Access { + f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true) + } else { + f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + } + }() + default: + f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode) + } + return true +} + +// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo. +func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + defer cancel() + + dstIP := f.determineDialAddr(id.LocalAddress) + cmd := buildPingCommand(ctx, dstIP, 5*time.Second) + + pingStart := time.Now() + if err := cmd.Run(); err != nil { + f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err) + return + } + rtt := time.Since(pingStart).Round(10 * time.Microsecond) + + f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v", + epID(id), icmpType, icmpCode) + + txBytes := f.synthesizeICMPv6EchoReply(id, icmpData) + + f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) +} + +// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back. 
+func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int { + replyICMP := make([]byte, len(icmpData)) + copy(replyICMP, icmpData) + + replyHdr := header.ICMPv6(replyICMP) + replyHdr.SetType(header.ICMPv6EchoReply) + replyHdr.SetChecksum(0) + // ICMPv6Checksum computes the pseudo-header internally from Src/Dst. + // Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero. + replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: replyHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, + })) + + return f.injectICMPv6Reply(id, replyICMP) +} + +// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer. +func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int { + ipHdr := make([]byte, header.IPv6MinimumSize) + ip := header.IPv6(ipHdr) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmpPayload)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 64, + SrcAddr: id.LocalAddress, + DstAddr: id.RemoteAddress, + }) + + fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload)) + fullPacket = append(fullPacket, ipHdr...) + fullPacket = append(fullPacket, icmpPayload...) + + if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil { + f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err) + return 0 + } + + return len(fullPacket) +} + +const ( + pingBin = "ping" + ping6Bin = "ping6" +) + // buildPingCommand creates a platform-specific ping command. -func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd { +// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6. 
+func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd { timeoutSec := int(timeout.Seconds()) if timeoutSec < 1 { timeoutSec = 1 } + isV6 := target.Is6() + timeoutStr := fmt.Sprintf("%d", timeoutSec) + switch runtime.GOOS { case "linux", "android": - return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String()) case "darwin", "ios": - return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + bin := pingBin + if isV6 { + bin = ping6Bin + } + return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String()) case "freebsd": - return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String()) case "openbsd", "netbsd": - return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String()) + bin := pingBin + if isV6 { + bin = ping6Bin + } + return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String()) case "windows": - return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) + return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) default: - return exec.CommandContext(ctx, "ping", "-c", "1", target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", target.String()) } } @@ -279,5 +443,9 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload [] return 0 } + if pc := f.endpoint.capture.Load(); pc != nil { + (*pc).Offer(fullPacket, true) + } + return len(fullPacket) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 
8e95522fe..c65ebcde0 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -1,9 +1,8 @@ package forwarder import ( - "fmt" "net" - "net/netip" + "strconv" "github.com/google/uuid" @@ -32,7 +31,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } }() - dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { @@ -94,15 +93,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn } func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) fields := nftypes.EventFields{ - FlowID: flowID, - Type: typ, - Direction: nftypes.Ingress, - Protocol: nftypes.TCP, - // TODO: handle ipv6 + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: nftypes.TCP, SourceIP: srcIp, DestIP: dstIp, SourcePort: id.RemotePort, diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 778a40842..d840ef06b 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "net" - "net/netip" + "strconv" "sync" "sync/atomic" "time" @@ -162,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { } }() - dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", 
dstAddr) if err != nil { f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) @@ -284,15 +284,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack // sendUDPEvent stores flow events for UDP connections func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) fields := nftypes.EventFields{ - FlowID: flowID, - Type: typ, - Direction: nftypes.Ingress, - Protocol: nftypes.UDP, - // TODO: handle ipv6 + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: nftypes.UDP, SourceIP: srcIp, DestIP: dstIp, SourcePort: id.RemotePort, diff --git a/client/firewall/uspfilter/hooks_filter.go b/client/firewall/uspfilter/hooks_filter.go index 8d3cc0f5c..f3adf5f8b 100644 --- a/client/firewall/uspfilter/hooks_filter.go +++ b/client/firewall/uspfilter/hooks_filter.go @@ -13,7 +13,6 @@ const ( ipv4HeaderMinLen = 20 ipv4ProtoOffset = 9 ipv4FlagsOffset = 6 - ipv4DstOffset = 16 ipProtoUDP = 17 ipProtoTCP = 6 ipv4FragOffMask = 0x1fff diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index f63fe3e45..b35be56c6 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -4,89 +4,32 @@ import ( "fmt" "net" "net/netip" - "sync" + "sync/atomic" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" ) -type localIPManager struct { - mu sync.RWMutex - - // fixed-size high array for upper byte of a IPv4 address - ipv4Bitmap [256]*ipv4LowBitmap +// localIPSnapshot is an immutable snapshot of local IP addresses, swapped +// atomically so reads are lock-free. 
+type localIPSnapshot struct { + ips map[netip.Addr]struct{} } -// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address -type ipv4LowBitmap struct { - bitmap [8192]uint32 +type localIPManager struct { + snapshot atomic.Pointer[localIPSnapshot] } func newLocalIPManager() *localIPManager { - return &localIPManager{} + m := &localIPManager{} + m.snapshot.Store(&localIPSnapshot{ + ips: make(map[netip.Addr]struct{}), + }) + return m } -func (m *localIPManager) setBitmapBit(ip net.IP) { - ipv4 := ip.To4() - if ipv4 == nil { - return - } - high := uint16(ipv4[0]) - low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) - - index := low / 32 - bit := low % 32 - - if m.ipv4Bitmap[high] == nil { - m.ipv4Bitmap[high] = &ipv4LowBitmap{} - } - - m.ipv4Bitmap[high].bitmap[index] |= 1 << bit -} - -func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { - if !ip.Is4() { - return - } - ipv4 := ip.AsSlice() - - high := uint16(ipv4[0]) - low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) - - if bitmap[high] == nil { - bitmap[high] = &ipv4LowBitmap{} - } - - index := low / 32 - bit := low % 32 - bitmap[high].bitmap[index] |= 1 << bit - - if _, exists := ipv4Set[ip]; !exists { - ipv4Set[ip] = struct{}{} - *ipv4Addresses = append(*ipv4Addresses, ip) - } -} - -func (m *localIPManager) checkBitmapBit(ip []byte) bool { - high := uint16(ip[0]) - low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3]) - - if m.ipv4Bitmap[high] == nil { - return false - } - - index := low / 32 - bit := low % 32 - return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 -} - -func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error { - m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) - return nil -} - -func (m *localIPManager) processInterface(iface net.Interface, 
bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { +func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) { addrs, err := iface.Addrs() if err != nil { log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) @@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv continue } - addr, ok := netip.AddrFromSlice(ip) + parsed, ok := netip.AddrFromSlice(ip) if !ok { log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name) continue } - if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil { - log.Debugf("process IP failed: %v", err) - } + parsed = parsed.Unmap() + ips[parsed] = struct{}{} + *addresses = append(*addresses, parsed) } } +// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically. func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { defer func() { if r := recover(); r != nil { @@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { } }() - var newIPv4Bitmap [256]*ipv4LowBitmap - ipv4Set := make(map[netip.Addr]struct{}) - var ipv4Addresses []netip.Addr + ips := make(map[netip.Addr]struct{}) + var addresses []netip.Addr - // 127.0.0.0/8 - newIPv4Bitmap[127] = &ipv4LowBitmap{} - for i := 0; i < 8192; i++ { - // #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct - newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF - } + // loopback + ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{} + ips[netip.IPv6Loopback()] = struct{}{} if iface != nil { - if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { - return err + ip := iface.Address().IP + ips[ip] = struct{}{} + addresses = append(addresses, ip) + if v6 := iface.Address().IPv6; v6.IsValid() { + ips[v6] = struct{}{} + addresses = append(addresses, v6) } } @@ -147,25 +91,24 @@ func 
(m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { // TODO: filter out down interfaces (net.FlagUp). Also handle the reverse // case where an interface comes up between refreshes. for _, intf := range interfaces { - m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) + processInterface(intf, ips, &addresses) } } - m.mu.Lock() - m.ipv4Bitmap = newIPv4Bitmap - m.mu.Unlock() + m.snapshot.Store(&localIPSnapshot{ips: ips}) - log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) + log.Debugf("Local IP addresses: %v", addresses) return nil } +// IsLocalIP checks if the given IP is a local address. Lock-free on the read path. func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { - if !ip.Is4() { - return false + s := m.snapshot.Load() + + if ip.Is4() && ip.As4()[0] == 127 { + return true } - m.mu.RLock() - defer m.mu.RUnlock() - - return m.checkBitmapBit(ip.AsSlice()) + _, found := s.ips[ip] + return found } diff --git a/client/firewall/uspfilter/localip_bench_test.go b/client/firewall/uspfilter/localip_bench_test.go new file mode 100644 index 000000000..14e12bd08 --- /dev/null +++ b/client/firewall/uspfilter/localip_bench_test.go @@ -0,0 +1,72 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func setupManager(b *testing.B) *localIPManager { + b.Helper() + m := newLocalIPManager() + mock := &IFaceMock{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + if err := m.UpdateLocalIPs(mock); err != nil { + b.Fatalf("UpdateLocalIPs: %v", err) + } + return m +} + +func BenchmarkIsLocalIP_v4_hit(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("100.64.0.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func 
BenchmarkIsLocalIP_v4_miss(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("8.8.8.8") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_v6_hit(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("fd00::1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_v6_miss(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("2001:db8::1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_loopback(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("127.0.0.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 6653947fa..0dc524c41 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) { expected: false, }, { - name: "IPv6 address", + name: "IPv6 address matches", setupAddr: wgaddr.Address{ - IP: netip.MustParseAddr("fe80::1"), + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + }, + testIP: netip.MustParseAddr("fd00::1"), + expected: true, + }, + { + name: "IPv6 address does not match", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + }, + testIP: netip.MustParseAddr("fd00::99"), + expected: false, + }, + { + name: "No aliasing between similar IPs", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("192.168.1.1"), Network: netip.MustParsePrefix("192.168.1.0/24"), }, - testIP: netip.MustParseAddr("fe80::1"), + testIP: netip.MustParseAddr("192.168.0.17"), 
expected: false, }, + { + name: "IPv6 loopback", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + }, + testIP: netip.MustParseAddr("::1"), + expected: true, + }, } for _, tt := range tests { @@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) { }) } } - -// MapImplementation is a version using map[string]struct{} -type MapImplementation struct { - localIPs map[string]struct{} -} - -func BenchmarkIPChecks(b *testing.B) { - interfaces := make([]net.IP, 16) - for i := range interfaces { - interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) - } - - // Setup bitmap - bitmapManager := newLocalIPManager() - for _, ip := range interfaces[:8] { // Add half of IPs - bitmapManager.setBitmapBit(ip) - } - - // Setup map version - mapManager := &MapImplementation{ - localIPs: make(map[string]struct{}), - } - for _, ip := range interfaces[:8] { - mapManager.localIPs[ip.String()] = struct{}{} - } - - b.Run("Bitmap_Hit", func(b *testing.B) { - ip := interfaces[4] - b.ResetTimer() - for i := 0; i < b.N; i++ { - bitmapManager.checkBitmapBit(ip) - } - }) - - b.Run("Bitmap_Miss", func(b *testing.B) { - ip := interfaces[12] - b.ResetTimer() - for i := 0; i < b.N; i++ { - bitmapManager.checkBitmapBit(ip) - } - }) - - b.Run("Map_Hit", func(b *testing.B) { - ip := interfaces[4] - b.ResetTimer() - for i := 0; i < b.N; i++ { - // nolint:gosimple - _ = mapManager.localIPs[ip.String()] - } - }) - - b.Run("Map_Miss", func(b *testing.B) { - ip := interfaces[12] - b.ResetTimer() - for i := 0; i < b.N; i++ { - // nolint:gosimple - _ = mapManager.localIPs[ip.String()] - } - }) -} - -func BenchmarkWGPosition(b *testing.B) { - wgIP := net.ParseIP("10.10.0.1") - - // Create two managers - one checks WG IP first, other checks it last - b.Run("WG_First", func(b *testing.B) { - bm := newLocalIPManager() - bm.setBitmapBit(wgIP) - b.ResetTimer() - for i := 0; i < b.N; i++ { - bm.checkBitmapBit(wgIP) - } - }) - - 
b.Run("WG_Last", func(b *testing.B) { - bm := newLocalIPManager() - // Fill with other IPs first - for i := 0; i < 15; i++ { - bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) - } - bm.setBitmapBit(wgIP) // Add WG IP last - b.ResetTimer() - for i := 0; i < b.N; i++ { - bm.checkBitmapBit(wgIP) - } - }) -} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index c24b18daa..5d51c1538 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -14,8 +14,6 @@ import ( nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) -var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") - var ( errInvalidIPHeaderLength = errors.New("invalid IP header length") ) @@ -26,10 +24,33 @@ const ( destinationPortOffset = 2 // IP address offsets in IPv4 header - sourceIPOffset = 12 - destinationIPOffset = 16 + ipv4SrcOffset = 12 + ipv4DstOffset = 16 + + // IP address offsets in IPv6 header + ipv6SrcOffset = 8 + ipv6DstOffset = 24 + + // IPv6 fixed header length + ipv6HeaderLen = 40 ) +// ipHeaderLen returns the IP header length based on the decoded layer type. +func ipHeaderLen(d *decoder) (int, error) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + n := int(d.ip4.IHL) * 4 + if n < 20 { + return 0, errInvalidIPHeaderLength + } + return n, nil + case layers.LayerTypeIPv6: + return ipv6HeaderLen, nil + default: + return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0]) + } +} + // ipv4Checksum calculates IPv4 header checksum. 
func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { @@ -235,14 +256,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) - + _, dstIP := extractPacketIPs(packetData, d) translatedIP, exists := m.getDNATTranslation(dstIP) if !exists { return false } - if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil { if m.logger.Enabled(nblog.LevelError) { m.logger.Error1("failed to rewrite packet destination: %v", err) } @@ -261,14 +281,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) - + srcIP, _ := extractPacketIPs(packetData, d) originalIP, exists := m.findReverseDNATMapping(srcIP) if !exists { return false } - if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil { if m.logger.Enabled(nblog.LevelError) { m.logger.Error1("failed to rewrite packet source: %v", err) } @@ -281,38 +300,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. -func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { +// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes. 
+func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]}) + dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]}) + case layers.LayerTypeIPv6: + src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16])) + dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16])) + } + return src, dst +} + +// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error { + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err + } + + switch d.decoded[0] { + case layers.LayerTypeIPv4: + return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource) + case layers.LayerTypeIPv6: + return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource) + default: + return fmt.Errorf("unknown IP layer: %v", d.decoded[0]) + } +} + +func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error { if !newIP.Is4() { - return ErrIPv4Only + return fmt.Errorf("cannot write IPv6 address into IPv4 packet") + } + + offset := ipv4DstOffset + if isSource { + offset = ipv4SrcOffset } var oldIP [4]byte - copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + copy(oldIP[:], packetData[offset:offset+4]) newIPBytes := newIP.As4() + copy(packetData[offset:offset+4], newIPBytes[:]) - copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength - } - + // Recalculate IPv4 header checksum binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum 
:= ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) + binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen])) + // Update transport checksums incrementally if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) + m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) + m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) + m.updateICMPChecksum(packetData, hdrLen) } } + return nil +} +func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error { + if !newIP.Is6() { + return fmt.Errorf("cannot write IPv4 address into IPv6 packet") + } + + offset := ipv6DstOffset + if isSource { + offset = ipv6SrcOffset + } + + var oldIP [16]byte + copy(oldIP[:], packetData[offset:offset+16]) + newIPBytes := newIP.As16() + copy(packetData[offset:offset+16], newIPBytes[:]) + + // IPv6 has no header checksum, only update transport checksums + if len(d.decoded) > 1 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + case layers.LayerTypeUDP: + m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + case layers.LayerTypeICMPv6: + // ICMPv6 checksum includes pseudo-header with addresses, use incremental update + m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + } + } return nil } @@ -360,6 +437,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } +// updateICMPv6Checksum updates ICMPv6 checksum after address change. +// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies. 
+func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { + icmpStart := ipHeaderLen + if len(packetData) < icmpStart+4 { + return + } + + checksumOffset := icmpStart + 2 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) +} + // incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -412,14 +503,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { } // addPortRedirection adds a port redirection rule. -func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() rule := portDNATRule{ protocol: protocol, - origPort: sourcePort, - targetPort: targetPort, + origPort: originalPort, + targetPort: translatedPort, targetIP: targetIP, } @@ -431,7 +522,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
// TODO: also delegate to nativeFirewall when available for kernel WG mode -func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { var layerType gopacket.LayerType switch protocol { case firewall.ProtocolTCP: @@ -442,16 +533,16 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return fmt.Errorf("unsupported protocol: %s", protocol) } - return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) + return m.addPortRedirection(localAddr, layerType, originalPort, translatedPort) } // removePortRedirection removes a port redirection rule. -func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { - return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + return rule.protocol == protocol && rule.origPort == originalPort && rule.targetPort == translatedPort && rule.targetIP.Compare(targetIP) == 0 }) if len(m.portDNATRules) == 0 { @@ -462,7 +553,7 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L } // RemoveInboundDNAT removes an inbound DNAT rule. 
-func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { var layerType gopacket.LayerType switch protocol { case firewall.ProtocolTCP: @@ -473,23 +564,23 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return fmt.Errorf("unsupported protocol: %s", protocol) } - return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) + return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort) } // AddOutputDNAT delegates to the native firewall if available. -func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if m.nativeFirewall == nil { return fmt.Errorf("output DNAT not supported without native firewall") } - return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) + return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveOutputDNAT delegates to the native firewall if available. -func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if m.nativeFirewall == nil { return nil } - return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) + return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. 
@@ -543,12 +634,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti // rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err } - tcpStart := ipHeaderLen + tcpStart := hdrLen if len(packetData) < tcpStart+4 { return fmt.Errorf("packet too short for TCP header") } @@ -574,12 +665,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, // rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err } - udpStart := ipHeaderLen + udpStart := hdrLen if len(packetData) < udpStart+8 { return fmt.Errorf("packet too short for UDP header") } diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index d2599e577..1e15c8c0c 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) { // Parse the packet fresh each time to get a clean decoder d := &decoder{decoded: []gopacket.LayerType{}} - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true - err = d.parser.DecodeLayers(testPacket, &d.decoded) + d.parser4.IgnoreUnsupported = true + d.parser6 = 
gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true + err = d.decodePacket(testPacket) assert.NoError(b, err) manager.translateOutboundDNAT(testPacket, d) @@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) { b.Run("decoder_extraction", func(b *testing.B) { // Create decoder once for comparison d := &decoder{decoded: []gopacket.LayerType{}} - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true - err := d.parser.DecodeLayers(packet, &d.decoded) + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true + err := d.decodePacket(packet) assert.NoError(b, err) for i := 0; i < b.N; i++ { diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 50743d006..4598c3901 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true - err := d.parser.DecodeLayers(packetData, &d.decoded) + err := d.decodePacket(packetData) require.NoError(t, err) return d } diff --git a/client/firewall/uspfilter/tracer.go 
b/client/firewall/uspfilter/tracer.go index 69c2519bf..696489e95 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -2,7 +2,9 @@ package uspfilter import ( "fmt" + "net" "net/netip" + "strconv" "time" "github.com/google/gopacket" @@ -112,10 +114,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, } func (p *PacketBuilder) Build() ([]byte, error) { - ip := p.buildIPLayer() - pktLayers := []gopacket.SerializableLayer{ip} + ipLayer, err := p.buildIPLayer() + if err != nil { + return nil, err + } + pktLayers := []gopacket.SerializableLayer{ipLayer} - transportLayer, err := p.buildTransportLayer(ip) + transportLayer, err := p.buildTransportLayer(ipLayer) if err != nil { return nil, err } @@ -129,30 +134,43 @@ func (p *PacketBuilder) Build() ([]byte, error) { return serializePacket(pktLayers) } -func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { +func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) { + if p.SrcIP.Is4() != p.DstIP.Is4() { + return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP) + } + proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6()) + if p.SrcIP.Is6() { + return &layers.IPv6{ + Version: 6, + HopLimit: 64, + NextHeader: proto, + SrcIP: p.SrcIP.AsSlice(), + DstIP: p.DstIP.AsSlice(), + }, nil + } return &layers.IPv4{ Version: 4, TTL: 64, - Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), + Protocol: proto, SrcIP: p.SrcIP.AsSlice(), DstIP: p.DstIP.AsSlice(), - } + }, nil } -func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { switch p.Protocol { case "tcp": - return p.buildTCPLayer(ip) + return p.buildTCPLayer(ipLayer) case "udp": - return p.buildUDPLayer(ip) + return p.buildUDPLayer(ipLayer) case "icmp": - return p.buildICMPLayer() + return 
p.buildICMPLayer(ipLayer) default: return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) } } -func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { tcp := &layers.TCP{ SrcPort: layers.TCPPort(p.SrcPort), DstPort: layers.TCPPort(p.DstPort), @@ -164,24 +182,44 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL PSH: p.TCPState != nil && p.TCPState.PSH, URG: p.TCPState != nil && p.TCPState.URG, } - if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { - return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + if err := tcp.SetNetworkLayerForChecksum(nl); err != nil { + return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + } } return []gopacket.SerializableLayer{tcp}, nil } -func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { udp := &layers.UDP{ SrcPort: layers.UDPPort(p.SrcPort), DstPort: layers.UDPPort(p.DstPort), } - if err := udp.SetNetworkLayerForChecksum(ip); err != nil { - return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + if err := udp.SetNetworkLayerForChecksum(nl); err != nil { + return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + } } return []gopacket.SerializableLayer{udp}, nil } -func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { + if p.SrcIP.Is6() || p.DstIP.Is6() { + icmp := &layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode), + } + if nl, ok := 
ipLayer.(gopacket.NetworkLayer); ok { + _ = icmp.SetNetworkLayerForChecksum(nl) + } + if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply { + echo := &layers.ICMPv6Echo{ + Identifier: 1, + SeqNumber: 1, + } + return []gopacket.SerializableLayer{icmp, echo}, nil + } + return []gopacket.SerializableLayer{icmp}, nil + } icmp := &layers.ICMPv4{ TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode), } @@ -204,14 +242,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) { return buf.Bytes(), nil } -func getIPProtocolNumber(protocol fw.Protocol) int { +func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol { switch protocol { case fw.ProtocolTCP: - return int(layers.IPProtocolTCP) + return layers.IPProtocolTCP case fw.ProtocolUDP: - return int(layers.IPProtocolUDP) + return layers.IPProtocolUDP case fw.ProtocolICMP: - return int(layers.IPProtocolICMPv4) + if isV6 { + return layers.IPProtocolICMPv6 + } + return layers.IPProtocolICMPv4 default: return 0 } @@ -234,7 +275,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa trace := &PacketTrace{Direction: direction} // Initial packet decoding - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false) return trace } @@ -256,6 +297,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa trace.DestinationPort = uint16(d.udp.DstPort) case layers.LayerTypeICMPv4: trace.Protocol = "ICMP" + case layers.LayerTypeICMPv6: + trace.Protocol = "ICMPv6" } trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d", @@ -319,6 +362,13 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string { flags&conntrack.TCPFin != 0) case layers.LayerTypeICMPv4: msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, 
d.icmp4.Seq) + case layers.LayerTypeICMPv6: + var id, seq uint16 + if len(d.icmp6.Payload) >= 4 { + id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1]) + seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3]) + } + msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq) } return msg } @@ -395,7 +445,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n trace.AddResult(StageRouteACL, msg, allowed) if allowed && m.forwarder.Load() != nil { - m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) + m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true) } trace.AddResult(StageCompleted, msgProcessingCompleted, allowed) @@ -415,7 +465,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageCompleted, "Packet dropped - decode error", false) return trace } @@ -434,7 +484,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) if portDNATApplied { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) return true } @@ -444,7 +494,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) if nat1to1Applied { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { 
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) return true } @@ -509,7 +559,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d * return false } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + srcIP, _ := extractPacketIPs(packetData, d) translated := m.translateInboundReverse(packetData, d) if translated { @@ -539,7 +589,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d return false } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + _, dstIP := extractPacketIPs(packetData, d) translated := m.translateOutboundDNAT(packetData, d) if translated { diff --git a/client/iface/bind/ice_bind_test.go b/client/iface/bind/ice_bind_test.go index 1fdd955c9..f49e68508 100644 --- a/client/iface/bind/ice_bind_test.go +++ b/client/iface/bind/ice_bind_test.go @@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) { ipv6Count++ } - assert.Equal(t, packetsPerFamily, ipv4Count) - assert.Equal(t, packetsPerFamily, ipv6Count) + // Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The + // routing-correctness checks above are the real assertions; the counts + // are a sanity bound to catch a totally silent path. 
+ minDelivered := packetsPerFamily * 80 / 100 + assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold") + assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold") } func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index e3a96590c..9b070aab8 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, if err != nil { return fmt.Errorf("failed to parse endpoint address: %w", err) } - addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) + addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port)) c.activityRecorder.UpsertAddress(peerKey, addrPort) } return nil diff --git a/client/iface/device/adapter.go b/client/iface/device/adapter.go index 6ebc05390..e3caaf930 100644 --- a/client/iface/device/adapter.go +++ b/client/iface/device/adapter.go @@ -2,7 +2,7 @@ package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { - ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) + ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error) UpdateAddr(address string) error ProtectSocket(fd int32) bool } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 198343fbd..cbe88c10c 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string searchDomainsToString = "" } - fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), 
t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return nil, err diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index acd5f6f11..ac8f8a51b 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -131,23 +131,32 @@ func (t *TunDevice) Device() *device.Device { // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (t *TunDevice) assignAddr() error { - cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) - if out, err := cmd.CombinedOutput(); err != nil { - log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out) - return err + if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil { + return fmt.Errorf("add v4 address: %s: %w", string(out), err) } - // dummy ipv6 so routing works - cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64") - if out, err := cmd.CombinedOutput(); err != nil { - log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out) + // Assign a dummy link-local so macOS enables IPv6 on the tun device. + // When a real overlay v6 is present, use that instead. 
+ v6Addr := "fe80::/64" + if t.address.HasIPv6() { + v6Addr = t.address.IPv6String() + } + if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil { + log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err) + t.address.ClearIPv6() } - routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name) - if out, err := routeCmd.CombinedOutput(); err != nil { - log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out) - return err + if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil { + return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err) } + + if t.address.HasIPv6() { + if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil { + log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err) + t.address.ClearIPv6() + } + } + return nil } diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 4357d1916..fc1c65efa 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -3,6 +3,7 @@ package device import ( "net/netip" "sync" + "sync/atomic" "golang.zx2c4.com/wireguard/tun" ) @@ -28,11 +29,20 @@ type PacketFilter interface { SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) } +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + // Offer submits a packet for capture. outbound is true for packets + // leaving the host (Read path), false for packets arriving (Write path). 
+ Offer(data []byte, outbound bool) +} + // FilteredDevice to override Read or Write of packets type FilteredDevice struct { tun.Device filter PacketFilter + capture atomic.Pointer[PacketCapture] mutex sync.RWMutex closeOnce sync.Once } @@ -63,20 +73,25 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() - if filter == nil { - return + if filter != nil { + for i := 0; i < n; i++ { + if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { + bufs = append(bufs[:i], bufs[i+1:]...) + sizes = append(sizes[:i], sizes[i+1:]...) + n-- + i-- + } + } } - for i := 0; i < n; i++ { - if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { - bufs = append(bufs[:i], bufs[i+1:]...) - sizes = append(sizes[:i], sizes[i+1:]...) - n-- - i-- + if pc := d.capture.Load(); pc != nil { + for i := 0; i < n; i++ { + (*pc).Offer(bufs[i][offset:offset+sizes[i]], true) } } @@ -85,6 +100,13 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er // Write wraps write method with filtering feature func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { + // Capture before filtering so dropped packets are still visible in captures. 
+ if pc := d.capture.Load(); pc != nil { + for _, buf := range bufs { + (*pc).Offer(buf[offset:], false) + } + } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -96,9 +118,10 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { filteredBufs := make([][]byte, 0, len(bufs)) dropped := 0 for _, buf := range bufs { - if !filter.FilterInbound(buf[offset:], len(buf)) { - filteredBufs = append(filteredBufs, buf) + if filter.FilterInbound(buf[offset:], len(buf)) { dropped++ + } else { + filteredBufs = append(filteredBufs, buf) } } @@ -113,3 +136,14 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.filter = filter d.mutex.Unlock() } + +// SetCapture sets or clears the packet capture sink. Pass nil to disable. +// Uses atomic store so the hot path (Read/Write) is a single pointer load +// with no locking overhead when capture is off. +func (d *FilteredDevice) SetCapture(pc PacketCapture) { + if pc == nil { + d.capture.Store(nil) + return + } + d.capture.Store(&pc) +} diff --git a/client/iface/device/device_filter_test.go b/client/iface/device/device_filter_test.go index eef783542..8fb16ca8d 100644 --- a/client/iface/device/device_filter_test.go +++ b/client/iface/device/device_filter_test.go @@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) { t.Errorf("unexpected error: %v", err) return } - if n != 0 { + if n != 1 { t.Errorf("expected n=1, got %d", n) return } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index aa77cee45..8368c8dce 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -151,8 +151,11 @@ func (t *TunDevice) MTU() uint16 { return t.mtu } -func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { - // todo implement +// UpdateAddr updates the device address. On iOS the tunnel is managed by the +// NetworkExtension, so we only store the new value. The extension picks up the +// change on the next tunnel reconfiguration. 
+func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error { + t.address = addr return nil } diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 2a836f846..25c4148a6 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { // assignAddr Adds IP address to the tunnel interface func (t *TunKernelDevice) assignAddr() error { - return t.link.assignAddr(t.address) + return t.link.assignAddr(&t.address) } func (t *TunKernelDevice) GetNet() *netstack.Net { diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 1a92b148f..b3bce3925 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -3,6 +3,7 @@ package device import ( "errors" "fmt" + "net/netip" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/conn" @@ -63,8 +64,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { return nil, fmt.Errorf("last ip: %w", err) } - log.Debugf("netstack using address: %s", t.address.IP) - t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu)) + addresses := []netip.Addr{t.address.IP} + if t.address.HasIPv6() { + addresses = append(addresses, t.address.IPv6) + } + log.Debugf("netstack using addresses: %v", addresses) + t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu)) log.Debugf("netstack using dns address: %s", dnsAddr) tunIface, net, err := t.nsTun.Create() if err != nil { diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 24654fc03..04c265c49 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/wgaddr" ) -type USPDevice struct { +type TunDevice struct { name string address 
wgaddr.Address port int @@ -30,10 +30,10 @@ type USPDevice struct { configurer WGConfigurer } -func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice { log.Infof("using userspace bind mode") - return &USPDevice{ + return &TunDevice{ name: name, address: address, port: port, @@ -43,7 +43,7 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu } } -func (t *USPDevice) Create() (WGConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") tunIface, err := tun.CreateTUN(t.name, int(t.mtu)) if err != nil { @@ -75,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -95,12 +95,12 @@ func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *USPDevice) UpdateAddr(address wgaddr.Address) error { +func (t *TunDevice) UpdateAddr(address wgaddr.Address) error { t.address = address return t.assignAddr() } -func (t *USPDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { t.configurer.Close() } @@ -115,39 +115,39 @@ func (t *USPDevice) Close() error { return nil } -func (t *USPDevice) WgAddress() wgaddr.Address { +func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } -func (t *USPDevice) MTU() uint16 { +func (t *TunDevice) MTU() uint16 { return t.mtu } -func (t *USPDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *USPDevice) FilteredDevice() *FilteredDevice { +func (t *TunDevice) FilteredDevice() *FilteredDevice { return 
t.filteredDevice } // Device returns the wireguard device -func (t *USPDevice) Device() *device.Device { +func (t *TunDevice) Device() *device.Device { return t.device } // assignAddr Adds IP address to the tunnel interface -func (t *USPDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { link := newWGLink(t.name) - return link.assignAddr(t.address) + return link.assignAddr(&t.address) } -func (t *USPDevice) GetNet() *netstack.Net { +func (t *TunDevice) GetNet() *netstack.Net { return nil } // GetICEBind returns the ICEBind instance -func (t *USPDevice) GetICEBind() EndpointManager { +func (t *TunDevice) GetICEBind() EndpointManager { return t.iceBind } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 96350df8a..f52392fa2 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -87,7 +87,21 @@ func (t *TunDevice) Create() (WGConfigurer, error) { err = nbiface.Set() if err != nil { t.device.Close() - return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err) + return nil, fmt.Errorf("set IPv4 interface MTU: %s", err) + } + + if t.address.HasIPv6() { + nbiface6, err := luid.IPInterface(windows.AF_INET6) + if err != nil { + log.Warnf("failed to get IPv6 interface for MTU, continuing v4-only: %v", err) + t.address.ClearIPv6() + } else { + nbiface6.NLMTU = uint32(t.mtu) + if err := nbiface6.Set(); err != nil { + log.Warnf("failed to set IPv6 interface MTU, continuing v4-only: %v", err) + t.address.ClearIPv6() + } + } } err = t.assignAddr() if err != nil { @@ -178,8 +192,21 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) { // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (t *TunDevice) assignAddr() error { luid := winipcfg.LUID(t.nativeTunDevice.LUID()) - log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) - return 
luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) + + v4Prefix := t.address.Prefix() + if t.address.HasIPv6() { + v6Prefix := t.address.IPv6Prefix() + log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name) + if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil { + log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err) + t.address.ClearIPv6() + return luid.SetIPAddresses([]netip.Prefix{v4Prefix}) + } + return nil + } + + log.Debugf("adding address %s to interface: %s", v4Prefix, t.name) + return luid.SetIPAddresses([]netip.Prefix{v4Prefix}) } func (t *TunDevice) GetNet() *netstack.Net { diff --git a/client/iface/device/kernel_module.go b/client/iface/device/kernel_module.go deleted file mode 100644 index 1bdd6f7c6..000000000 --- a/client/iface/device/kernel_module.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build (!linux && !freebsd) || android - -package device - -// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) -func WireGuardModuleIsLoaded() bool { - return false -} diff --git a/client/iface/device/kernel_module_freebsd.go b/client/iface/device/kernel_module_freebsd.go deleted file mode 100644 index dd6c8b408..000000000 --- a/client/iface/device/kernel_module_freebsd.go +++ /dev/null @@ -1,18 +0,0 @@ -package device - -// WireGuardModuleIsLoaded check if kernel support wireguard -func WireGuardModuleIsLoaded() bool { - // Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd) - // we are currently do not use it, since it is required to add wireguard kernel support to - // - https://github.com/netbirdio/netbird/tree/main/sharedsock - // - https://github.com/mdlayher/socket - // TODO: implement kernel space - return false -} - -// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it -func ModuleTunIsLoaded() bool { - // Assume tun supported by freebsd kernel by default - // 
TODO: implement check for module loaded in kernel or build-it - return true -} diff --git a/client/iface/device/kernel_module_nonlinux.go b/client/iface/device/kernel_module_nonlinux.go new file mode 100644 index 000000000..58d97080b --- /dev/null +++ b/client/iface/device/kernel_module_nonlinux.go @@ -0,0 +1,13 @@ +//go:build !linux || android + +package device + +// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available. +func WireGuardModuleIsLoaded() bool { + return false +} + +// ModuleTunIsLoaded reports whether the tun device is available. +func ModuleTunIsLoaded() bool { + return true +} diff --git a/client/iface/device/wg_link_freebsd.go b/client/iface/device/wg_link_freebsd.go index 1b06e0e15..87df89183 100644 --- a/client/iface/device/wg_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -2,6 +2,7 @@ package device import ( "fmt" + "os/exec" log "github.com/sirupsen/logrus" @@ -57,32 +58,32 @@ func (l *wgLink) up() error { return nil } -func (l *wgLink) assignAddr(address wgaddr.Address) error { +func (l *wgLink) assignAddr(address *wgaddr.Address) error { link, err := freebsd.LinkByName(l.name) if err != nil { return fmt.Errorf("link by name: %w", err) } - ip := address.IP.String() - - // Convert prefix length to hex netmask prefixLen := address.Network.Bits() - if !address.IP.Is4() { - return fmt.Errorf("IPv6 not supported for interface assignment") - } - maskBits := uint32(0xffffffff) << (32 - prefixLen) mask := fmt.Sprintf("0x%08x", maskBits) - log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) + log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name) - err = link.AssignAddr(ip, mask) - if err != nil { + if err := link.AssignAddr(address.IP.String(), mask); err != nil { return fmt.Errorf("assign addr: %w", err) } - err = link.Up() - if err != nil { + if address.HasIPv6() { + log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name) + cmd := 
exec.Command("ifconfig", l.name, "inet6", address.IPv6String()) + if out, err := cmd.CombinedOutput(); err != nil { + log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err) + address.ClearIPv6() + } + } + + if err := link.Up(); err != nil { return fmt.Errorf("up: %w", err) } diff --git a/client/iface/device/wg_link_linux.go b/client/iface/device/wg_link_linux.go index d941cd022..6a02cb356 100644 --- a/client/iface/device/wg_link_linux.go +++ b/client/iface/device/wg_link_linux.go @@ -4,6 +4,8 @@ package device import ( "fmt" + "net" + "net/netip" "os" log "github.com/sirupsen/logrus" @@ -92,7 +94,7 @@ func (l *wgLink) up() error { return nil } -func (l *wgLink) assignAddr(address wgaddr.Address) error { +func (l *wgLink) assignAddr(address *wgaddr.Address) error { //delete existing addresses list, err := netlink.AddrList(l, 0) if err != nil { @@ -110,20 +112,16 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error { } name := l.attrs.Name - addrStr := address.String() - log.Debugf("adding address %s to interface: %s", addrStr, name) - - addr, err := netlink.ParseAddr(addrStr) - if err != nil { - return fmt.Errorf("parse addr: %w", err) + if err := l.addAddr(name, address.Prefix()); err != nil { + return err } - err = netlink.AddrAdd(l, addr) - if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", name, addrStr) - } else if err != nil { - return fmt.Errorf("add addr: %w", err) + if address.HasIPv6() { + if err := l.addAddr(name, address.IPv6Prefix()); err != nil { + log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err) + address.ClearIPv6() + } } // On linux, the link must be brought up @@ -133,3 +131,22 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error { return nil } + +func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error { + log.Debugf("adding address %s to interface: %s", prefix, 
ifaceName) + + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: prefix.Addr().AsSlice(), + Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), + }, + } + + if err := netlink.AddrAdd(l, addr); os.IsExist(err) { + log.Infof("interface %s already has the address: %s", ifaceName, prefix) + } else if err != nil { + return fmt.Errorf("add addr %s: %w", prefix, err) + } + + return nil +} diff --git a/client/iface/iface.go b/client/iface/iface.go index 655dd1682..78c5080e7 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -57,7 +57,7 @@ type wgProxyFactory interface { type WGIFaceOpts struct { IFaceName string - Address string + Address wgaddr.Address WGPort int WGPrivKey string MTU uint16 @@ -141,16 +141,11 @@ func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { } // UpdateAddr updates address of the interface -func (w *WGIface) UpdateAddr(newAddr string) error { +func (w *WGIface) UpdateAddr(newAddr wgaddr.Address) error { w.mu.Lock() defer w.mu.Unlock() - addr, err := wgaddr.ParseWGAddress(newAddr) - if err != nil { - return err - } - - return w.tun.UpdateAddr(addr) + return w.tun.UpdateAddr(newAddr) } // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new.go similarity index 50% rename from client/iface/iface_new_windows.go rename to client/iface/iface_new.go index dfd9028e7..28f350e3f 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new.go @@ -1,33 +1,28 @@ +//go:build !linux && !ios && !android && !js + package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, 
err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) var tun WGTunDevice if netstack.IsEnabled() { - tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + tun = device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) } else { - tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + tun = device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) } - wgIFace := &WGIface{ + return &WGIface{ userspaceBind: true, tun: tun, wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), - } - return wgIFace, nil - + }, nil } diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 3b68f63f2..e28dcc0de 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -4,23 +4,17 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) if netstack.IsEnabled() { wgIFace := &WGIface{ userspaceBind: true, - tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, 
netstack.ListenAddr()), + tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil @@ -28,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, - tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), + tun: device.NewTunDevice(opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go deleted file mode 100644 index 9f21ec950..000000000 --- a/client/iface/iface_new_darwin.go +++ /dev/null @@ -1,35 +0,0 @@ -//go:build !ios - -package iface - -import ( - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" - "github.com/netbirdio/netbird/client/iface/wgproxy" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - - var tun WGTunDevice - if netstack.IsEnabled() { - tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) - } else { - tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) - } - - wgIFace := &WGIface{ - userspaceBind: true, - tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), - } - return wgIFace, nil -} diff --git 
a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go deleted file mode 100644 index a342bd579..000000000 --- a/client/iface/iface_new_freebsd.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build freebsd - -package iface - -import ( - "fmt" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" - "github.com/netbirdio/netbird/client/iface/wgproxy" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{} - - if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil - } - - if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil - } - - return nil, fmt.Errorf("couldn't check or load tun module") -} diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 5d6a32e39..41e0022b2 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -5,21 +5,15 @@ package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // 
NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) wgIFace := &WGIface{ - tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), + tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } diff --git a/client/iface/iface_new_js.go b/client/iface/iface_new_js.go index ad913ab04..9f7a3ba62 100644 --- a/client/iface/iface_new_js.go +++ b/client/iface/iface_new_js.go @@ -4,21 +4,15 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode) func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - relayBind := bind.NewRelayBindJS() wgIface := &WGIface{ - tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), + tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), userspaceBind: true, wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU), } diff --git a/client/iface/iface_new_linux.go b/client/iface/iface_new_linux.go index d84035403..65ce67e88 100644 --- a/client/iface/iface_new_linux.go +++ 
b/client/iface/iface_new_linux.go @@ -3,44 +3,40 @@ package iface import ( - "fmt" + "errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{} - if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) + return &WGIface{ + tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), + }, nil } if device.WireGuardModuleIsLoaded() { - wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) - wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU) - return wgIFace, nil - } - if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil + return &WGIface{ + tun: device.NewKernelDevice(opts.IFaceName, opts.Address, opts.WGPort, 
opts.WGPrivKey, opts.MTU, opts.TransportNet), + wgProxyFactory: wgproxy.NewKernelFactory(opts.WGPort, opts.MTU), + }, nil } - return nil, fmt.Errorf("couldn't check or load tun module") + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) + return &WGIface{ + tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), + }, nil + } + + return nil, errors.New("tun module not available") } diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index 6bbfeaa63..dbeb69bc6 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -16,6 +16,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/stdnet" ) @@ -48,7 +49,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: addr, + Address: wgaddr.MustParseWGAddress(addr), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -84,7 +85,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { //update WireGuard address addr = "100.64.0.2/8" - err = iface.UpdateAddr(addr) + err = iface.UpdateAddr(wgaddr.MustParseWGAddress(addr)) if err != nil { t.Fatal(err) } @@ -130,7 +131,7 @@ func Test_CreateInterface(t *testing.T) { } opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: 33100, WGPrivKey: key, MTU: DefaultMTU, @@ -174,7 +175,7 @@ func Test_Close(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -219,7 +220,7 @@ func TestRecreation(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: 
wgaddr.MustParseWGAddress(wgIP), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -291,7 +292,7 @@ func Test_ConfigureInterface(t *testing.T) { } opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -347,7 +348,7 @@ func Test_UpdatePeer(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: 33100, WGPrivKey: key, MTU: DefaultMTU, @@ -417,7 +418,7 @@ func Test_RemovePeer(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: 33100, WGPrivKey: key, MTU: DefaultMTU, @@ -482,7 +483,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer1 := WGIFaceOpts{ IFaceName: peer1ifaceName, - Address: peer1wgIP.String(), + Address: wgaddr.MustParseWGAddress(peer1wgIP.String()), WGPort: peer1wgPort, WGPrivKey: peer1Key.String(), MTU: DefaultMTU, @@ -522,7 +523,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer2 := WGIFaceOpts{ IFaceName: peer2ifaceName, - Address: peer2wgIP.String(), + Address: wgaddr.MustParseWGAddress(peer2wgIP.String()), WGPort: peer2wgPort, WGPrivKey: peer2Key.String(), MTU: DefaultMTU, diff --git a/client/iface/netstack/tun.go b/client/iface/netstack/tun.go index 346ae29ec..8c7526bbb 100644 --- a/client/iface/netstack/tun.go +++ b/client/iface/netstack/tun.go @@ -13,7 +13,7 @@ import ( const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" type NetStackTun struct { //nolint:revive - address netip.Addr + addresses []netip.Addr dnsAddress netip.Addr mtu int listenAddress string @@ -22,9 +22,9 @@ type NetStackTun struct { //nolint:revive tundev tun.Device } -func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { +func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { return &NetStackTun{ - address: address, + 
addresses: addresses, dnsAddress: dnsAddress, mtu: mtu, listenAddress: listenAddress, @@ -33,7 +33,7 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { nsTunDev, tunNet, err := netstack.CreateNetTUN( - []netip.Addr{t.address}, + t.addresses, []netip.Addr{t.dnsAddress}, t.mtu) if err != nil { diff --git a/client/iface/wgaddr/address.go b/client/iface/wgaddr/address.go index 078f8be95..43d1ec9aa 100644 --- a/client/iface/wgaddr/address.go +++ b/client/iface/wgaddr/address.go @@ -3,12 +3,18 @@ package wgaddr import ( "fmt" "net/netip" + + "github.com/netbirdio/netbird/shared/netiputil" ) // Address WireGuard parsed address type Address struct { IP netip.Addr Network netip.Prefix + + // IPv6 overlay address, if assigned. + IPv6 netip.Addr + IPv6Net netip.Prefix } // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address @@ -23,6 +29,60 @@ func ParseWGAddress(address string) (Address, error) { }, nil } -func (addr Address) String() string { - return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits()) +// HasIPv6 reports whether a v6 overlay address is assigned. +func (addr Address) HasIPv6() bool { + return addr.IPv6.IsValid() +} + +func (addr Address) String() string { + return addr.Prefix().String() +} + +// IPv6String returns the v6 address in CIDR notation, or empty string if none. +func (addr Address) IPv6String() string { + if !addr.HasIPv6() { + return "" + } + return addr.IPv6Prefix().String() +} + +// Prefix returns the v4 host address with its network prefix length (e.g. 100.64.0.1/16). +func (addr Address) Prefix() netip.Prefix { + return netip.PrefixFrom(addr.IP, addr.Network.Bits()) +} + +// IPv6Prefix returns the v6 host address with its network prefix length, or a zero prefix if none. 
+func (addr Address) IPv6Prefix() netip.Prefix { + if !addr.HasIPv6() { + return netip.Prefix{} + } + return netip.PrefixFrom(addr.IPv6, addr.IPv6Net.Bits()) +} + +// SetIPv6FromCompact decodes a compact prefix (5 or 17 bytes) and sets the IPv6 fields. +// Returns an error if the bytes are invalid. A nil or empty input is a no-op. +// +//nolint:recvcheck +func (addr *Address) SetIPv6FromCompact(raw []byte) error { + if len(raw) == 0 { + return nil + } + prefix, err := netiputil.DecodePrefix(raw) + if err != nil { + return fmt.Errorf("decode v6 overlay address: %w", err) + } + if !prefix.Addr().Is6() { + return fmt.Errorf("expected IPv6 address, got %s", prefix.Addr()) + } + addr.IPv6 = prefix.Addr() + addr.IPv6Net = prefix.Masked() + return nil +} + +// ClearIPv6 removes the IPv6 overlay address, leaving only v4. +// +//nolint:recvcheck +func (addr *Address) ClearIPv6() { + addr.IPv6 = netip.Addr{} + addr.IPv6Net = netip.Prefix{} } diff --git a/client/iface/wgaddr/address_test_helpers.go b/client/iface/wgaddr/address_test_helpers.go new file mode 100644 index 000000000..87403e789 --- /dev/null +++ b/client/iface/wgaddr/address_test_helpers.go @@ -0,0 +1,10 @@ +package wgaddr + +// MustParseWGAddress parses and returns a WG Address, panicking on error. +func MustParseWGAddress(address string) Address { + a, err := ParseWGAddress(address) + if err != nil { + panic(err) + } + return a +} diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 9ac3ea6df..be6f3806e 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/netip" - "strings" + "sync" log "github.com/sirupsen/logrus" @@ -196,18 +196,25 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } } -// fakeAddress returns a fake address that is used to as an identifier for the peer. -// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. 
+// fakeAddress returns a fake address that is used as an identifier for the peer. +// The fake address is in the format of 127.1.x.x where x.x is derived from the +// last two bytes of the peer address (works for both IPv4 and IPv6). func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { - octets := strings.Split(peerAddress.IP.String(), ".") - if len(octets) != 4 { - return nil, fmt.Errorf("invalid IP format") + if peerAddress == nil { + return nil, fmt.Errorf("nil peer address") + } + if peerAddress.Port < 0 || peerAddress.Port > 65535 { + return nil, fmt.Errorf("invalid UDP port: %d", peerAddress.Port) } - fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) - if err != nil { - return nil, fmt.Errorf("parse new IP: %w", err) + addr, ok := netip.AddrFromSlice(peerAddress.IP) + if !ok { + return nil, fmt.Errorf("invalid IP format") } + addr = addr.Unmap() + + raw := addr.As16() + fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]}) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil diff --git a/client/installer.nsis b/client/installer.nsis index 96d60a785..63bff1c5b 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -201,7 +201,18 @@ Pop $0 Function .onInit StrCpy $INSTDIR "${INSTALL_DIR}" +; Default autostart to enabled so silent installs (/S) match the interactive default +StrCpy $AutostartEnabled "1" + +; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live +; in the 32-bit view. Fall back to it so upgrades still find them. 
+SetRegView 64 ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" +${If} $R0 == "" + SetRegView 32 + ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" + SetRegView 64 +${EndIf} ${If} $R0 != "" # if silent install jump to uninstall step IfSilent uninstall @@ -214,6 +225,10 @@ ${If} $R0 != "" ${EndIf} FunctionEnd + +Function un.onInit +SetRegView 64 +FunctionEnd ###################################################################### Section -MainProgram ${INSTALL_TYPE} @@ -228,6 +243,7 @@ Section -MainProgram !else File /r "..\\dist\\netbird_windows_amd64\\" !endif + File "..\\client\\ui\\assets\\netbird.png" SectionEnd ###################################################################### @@ -247,9 +263,11 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" ; Create autostart registry entry based on checkbox DetailPrint "Autostart enabled: $AutostartEnabled" ${If} $AutostartEnabled == "1" - WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe" + WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"' DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe" ${Else} + DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" + ; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DetailPrint "Autostart not enabled by user" ${EndIf} @@ -283,6 +301,8 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ; Remove autostart registry entry DetailPrint "Removing autostart registry entry if exists..." +DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" +; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" ; Handle data deletion based on checkbox @@ -321,6 +341,7 @@ DetailPrint "Removing registry keys..." 
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}" +DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}" DetailPrint "Removing application directory from PATH..." EnVar::SetHKLM diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index dd6f9479a..c54a3e897 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "errors" "fmt" - "net" "net/netip" "strconv" "sync" @@ -19,6 +18,7 @@ import ( "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" ) var ErrSourceRangesEmpty = errors.New("sources range is empty") @@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { newRulePairs := make(map[id.RuleID][]firewall.Rule) ipsetByRuleSelectors := make(map[string]string) + // TODO: deny rules should be fatal: if a deny rule fails to apply, we must + // roll back all allow rules to avoid a fail-open where allowed traffic bypasses + // the missing deny. Currently we accumulate errors and continue. 
+ var merr *multierror.Error for _, r := range rules { // if this rule is member of rule selection with more than DefaultIPsCountForSet // it's IP address can be used in the ipset for firewall manager which supports it @@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) if err != nil { - log.Errorf("failed to apply firewall rule: %+v, %v", r, err) - d.rollBack(newRulePairs) - break + merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err)) + continue } if len(rulePair) > 0 { d.peerRulesPairs[pairID] = rulePair @@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } } + if merr != nil { + log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr)) + } + for pairID, rules := range d.peerRulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { @@ -216,9 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, ) (id.RuleID, []firewall.Rule, error) { - ip := net.ParseIP(r.PeerIP) - if ip == nil { - return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") + ip, err := extractRuleIP(r) + if err != nil { + return "", nil, err } protocol, err := convertToFirewallProtocol(r.Protocol) @@ -289,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool { func (d *DefaultManager) addInRules( id []byte, - ip net.IP, + ip netip.Addr, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, ipsetName string, ) ([]firewall.Rule, error) { - rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName) + rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName) if err != nil { return nil, fmt.Errorf("add firewall rule: %w", err) } @@ -305,7 +312,7 @@ func (d *DefaultManager) addInRules( func (d 
*DefaultManager) addOutRules( id []byte, - ip net.IP, + ip netip.Addr, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, @@ -315,7 +322,7 @@ func (d *DefaultManager) addOutRules( return nil, nil } - rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName) + rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName) if err != nil { return nil, fmt.Errorf("add firewall rule: %w", err) } @@ -323,9 +330,9 @@ func (d *DefaultManager) addOutRules( return rule, nil } -// getPeerRuleID() returns unique ID for the rule based on its parameters. +// getPeerRuleID returns unique ID for the rule based on its parameters. func (d *DefaultManager) getPeerRuleID( - ip net.IP, + ip netip.Addr, proto firewall.Protocol, direction int, port *firewall.Port, @@ -344,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) } -func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { - log.Debugf("rollback ACL to previous state") - for _, rules := range newRulePairs { - for _, rule := range rules { - if err := d.firewall.DeletePeerRule(rule); err != nil { - log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err) - } + +// extractRuleIP extracts the peer IP from a firewall rule. +// If sourcePrefixes is populated (new management), decode the first entry and use its address. +// Otherwise fall back to the deprecated PeerIP string field (old management). 
+func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) { + if len(r.SourcePrefixes) > 0 { + addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0]) + if err != nil { + return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err) } + return addr.Unmap(), nil } + + //nolint:staticcheck // PeerIP used for backward compatibility with old management + addr, err := netip.ParseAddr(r.PeerIP) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule") + } + return addr.Unmap(), nil } func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go index bdfd07430..afc8ee77f 100644 --- a/client/internal/auth/auth.go +++ b/client/internal/auth/auth.go @@ -321,6 +321,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) { a.config.DisableFirewall, a.config.BlockLANAccess, a.config.BlockInbound, + a.config.DisableIPv6, a.config.LazyConnectionEnabled, a.config.EnableSSHRoot, a.config.EnableSSHSFTP, diff --git a/client/internal/connect.go b/client/internal/connect.go index ac498f719..8c0e9b1ba 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -14,10 +14,13 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" @@ -333,6 +336,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.statusRecorder.MarkSignalConnected() relayURLs, token := parseRelayInfo(loginResp) + if override, ok := peer.OverrideRelayURLs(); ok { + log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override) + relayURLs = override + } 
peerConfig := loginResp.GetPeerConfig() engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath) @@ -532,9 +539,20 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf if config.NetworkMonitor != nil { nm = *config.NetworkMonitor } + wgAddr, err := wgaddr.ParseWGAddress(peerConfig.Address) + if err != nil { + return nil, fmt.Errorf("parse overlay address %q: %w", peerConfig.Address, err) + } + + if !config.DisableIPv6 { + if err := wgAddr.SetIPv6FromCompact(peerConfig.GetAddressV6()); err != nil { + log.Warn(err) + } + } + engineConf := &EngineConfig{ WgIfaceName: config.WgIface, - WgAddr: peerConfig.Address, + WgAddr: wgAddr, IFaceBlackList: config.IFaceBlackList, DisableIPv6Discovery: config.DisableIPv6Discovery, WgPrivateKey: key, @@ -559,6 +577,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf DisableFirewall: config.DisableFirewall, BlockLANAccess: config.BlockLANAccess, BlockInbound: config.BlockInbound, + DisableIPv6: config.DisableIPv6, LazyConnectionEnabled: config.LazyConnectionEnabled, @@ -633,6 +652,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.DisableFirewall, config.BlockLANAccess, config.BlockInbound, + config.DisableIPv6, config.LazyConnectionEnabled, config.EnableSSHRoot, config.EnableSSHSFTP, diff --git a/client/internal/connect_android_default.go b/client/internal/connect_android_default.go index 190341c4a..b05e91fec 100644 --- a/client/internal/connect_android_default.go +++ b/client/internal/connect_android_default.go @@ -40,6 +40,10 @@ func (noopNetworkChangeListener) SetInterfaceIP(string) { // network stack, not by OS-level interface configuration. } +func (noopNetworkChangeListener) SetInterfaceIPv6(string) { + // No-op: same as SetInterfaceIP, IPv6 overlay is managed by userspace stack. +} + // noopDnsReadyListener is a stub for embed.Client on Android. 
// DNS readiness notifications are not needed in netstack/embed mode // since system DNS is disabled and DNS resolution happens externally. diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index bddb9a69e..9c50f02b3 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -21,6 +21,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" @@ -30,6 +31,7 @@ import ( "github.com/netbirdio/netbird/client/internal/updater/installer" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" ) const readmeContent = `Netbird debug bundle @@ -61,6 +63,7 @@ allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. cpu.prof: CPU profiling information. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. +capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data. Anonymization Process @@ -234,6 +237,7 @@ type BundleGenerator struct { logPath string tempDir string cpuProfile []byte + capturePath string refreshStatus func() // Optional callback to refresh status before bundle generation clientMetrics MetricsExporter @@ -257,7 +261,8 @@ type GeneratorDependencies struct { LogPath string TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used. 
CPUProfile []byte - RefreshStatus func() // Optional callback to refresh status before bundle generation + CapturePath string + RefreshStatus func() ClientMetrics MetricsExporter } @@ -277,6 +282,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen logPath: deps.LogPath, tempDir: deps.TempDir, cpuProfile: deps.CPUProfile, + capturePath: deps.CapturePath, refreshStatus: deps.RefreshStatus, clientMetrics: deps.ClientMetrics, @@ -346,6 +352,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add CPU profile to debug bundle: %v", err) } + if err := g.addCaptureFile(); err != nil { + log.Errorf("failed to add capture file to debug bundle: %v", err) + } + if err := g.addStackTrace(); err != nil { log.Errorf("failed to add stack trace to debug bundle: %v", err) } @@ -575,6 +585,9 @@ func isSensitiveEnvVar(key string) bool { func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { configContent.WriteString("NetBird Client Configuration:\n\n") + if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil { + configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String())) + } configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface)) configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort)) if g.internalConfig.NetworkMonitor != nil { @@ -599,6 +612,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) if g.internalConfig.EnableSSHRemotePortForwarding != nil { configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) } + if g.internalConfig.DisableSSHAuth != nil { + configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth)) + } + if g.internalConfig.SSHJWTCacheTTL != nil { + configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL)) + } 
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) @@ -606,6 +625,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess)) configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound)) + configContent.WriteString(fmt.Sprintf("DisableIPv6: %v\n", g.internalConfig.DisableIPv6)) if g.internalConfig.DisableNotifications != nil { configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications)) @@ -625,6 +645,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) } configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) + configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU)) } func (g *BundleGenerator) addProf() (err error) { @@ -669,6 +690,29 @@ func (g *BundleGenerator) addCPUProfile() error { return nil } +func (g *BundleGenerator) addCaptureFile() error { + if g.capturePath == "" { + return nil + } + + if g.anonymize { + log.Info("skipping capture file in anonymized bundle (contains raw packet data)") + return nil + } + + f, err := os.Open(g.capturePath) + if err != nil { + return fmt.Errorf("open capture file: %w", err) + } + defer f.Close() + + if err := g.addFileToZip(f, "capture.pcap"); err != nil { + return fmt.Errorf("add capture file to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addStackTrace() error { buf := make([]byte, 5242880) // 5 MB buffer n := runtime.Stack(buf, true) @@ -1252,6 +1296,21 @@ func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anon 
config.Address = anonymizer.AnonymizeIP(addr).String() } + if len(config.GetAddressV6()) > 0 { + v6Prefix, err := netiputil.DecodePrefix(config.GetAddressV6()) + if err != nil { + config.AddressV6 = nil + } else { + anonV6 := anonymizer.AnonymizeIP(v6Prefix.Addr()) + b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonV6, v6Prefix.Bits())) + if err != nil { + config.AddressV6 = nil + } else { + config.AddressV6 = b + } + } + } + anonymizeSSHConfig(config.SshConfig) config.Dns = anonymizer.AnonymizeString(config.Dns) @@ -1354,8 +1413,20 @@ func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.An return } + //nolint:staticcheck // PeerIP used for backward compatibility if addr, err := netip.ParseAddr(rule.PeerIP); err == nil { - rule.PeerIP = anonymizer.AnonymizeIP(addr).String() + rule.PeerIP = anonymizer.AnonymizeIP(addr).String() //nolint:staticcheck + } + + for i, raw := range rule.GetSourcePrefixes() { + p, err := netiputil.DecodePrefix(raw) + if err != nil { + continue + } + anonAddr := anonymizer.AnonymizeIP(p.Addr()) + if b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonAddr, p.Bits())); err == nil { + rule.SourcePrefixes[i] = b + } } } diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 6b5bb911c..39b972244 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -5,19 +5,33 @@ import ( "bytes" "encoding/json" "net" + "net/netip" + "net/url" "os" "path/filepath" + "reflect" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/configs" + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" ) +func mustEncodePrefix(t *testing.T, p 
netip.Prefix) []byte { + t.Helper() + b, err := netiputil.EncodePrefix(p) + require.NoError(t, err) + return b +} + func TestAnonymizeStateFile(t *testing.T) { testState := map[string]json.RawMessage{ "null_state": json.RawMessage("null"), @@ -168,7 +182,7 @@ func TestAnonymizeStateFile(t *testing.T) { assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"]) - assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged + assert.NotEqual(t, "fd00::1", state["private_ipv6"]) // ULA IPv6 anonymized (global ID is a fingerprint) assert.NotEqual(t, "test.example.com", state["domain"]) assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain")) assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged @@ -272,11 +286,13 @@ func mustMarshal(v any) json.RawMessage { } func TestAnonymizeNetworkMap(t *testing.T) { + origV6Prefix := netip.MustParsePrefix("2001:db8:abcd::5/64") networkMap := &mgmProto.NetworkMap{ PeerConfig: &mgmProto.PeerConfig{ - Address: "203.0.113.5", - Dns: "1.2.3.4", - Fqdn: "peer1.corp.example.com", + Address: "203.0.113.5", + AddressV6: mustEncodePrefix(t, origV6Prefix), + Dns: "1.2.3.4", + Fqdn: "peer1.corp.example.com", SshConfig: &mgmProto.SSHConfig{ SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."), }, @@ -350,6 +366,12 @@ func TestAnonymizeNetworkMap(t *testing.T) { require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn) require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain")) + // Verify AddressV6 is anonymized but preserves prefix length + anonV6Prefix, err := netiputil.DecodePrefix(peerCfg.AddressV6) + require.NoError(t, err) + assert.Equal(t, origV6Prefix.Bits(), anonV6Prefix.Bits(), "prefix length must be preserved") + assert.NotEqual(t, origV6Prefix.Addr(), anonV6Prefix.Addr(), "IPv6 address must be anonymized") + // Verify SSH 
key is replaced require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey) @@ -471,8 +493,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) { anonymize: false, input: map[string]any{ jsonKeyServiceEnv: map[string]any{ - "HOME": "/root", - "PATH": "/usr/bin", + "HOME": "/root", + "PATH": "/usr/bin", "NB_LOG_LEVEL": "debug", }, }, @@ -489,9 +511,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) { anonymize: false, input: map[string]any{ jsonKeyServiceEnv: map[string]any{ - "NB_SETUP_KEY": "abc123", - "NB_API_TOKEN": "tok_xyz", - "NB_LOG_LEVEL": "info", + "NB_SETUP_KEY": "abc123", + "NB_API_TOKEN": "tok_xyz", + "NB_LOG_LEVEL": "info", }, }, check: func(t *testing.T, params map[string]any) { @@ -655,8 +677,6 @@ func isInCGNATRange(ip net.IP) bool { } func TestAnonymizeFirewallRules(t *testing.T) { - // TODO: Add ipv6 - // Example iptables-save output iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 *filter @@ -692,17 +712,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes) Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) pkts bytes target prot opt in out source destination` - // Example nftables output + // Example ip6tables-save output + ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s fd00:1234::1/128 -j ACCEPT +-A INPUT -s 2607:f8b0:4005::1/128 -j DROP +-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT +COMMIT` + + // Example nftables output with IPv6 nftablesRules := `table inet filter { chain input { type filter hook input priority filter; policy accept; ip saddr 192.168.1.1 accept ip saddr 44.192.140.1 drop + ip6 saddr 2607:f8b0:4005::1 drop + ip6 saddr fd00:1234::1 accept } chain forward { type filter hook forward priority filter; policy accept; ip saddr 10.0.0.0/8 drop ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + ip6 saddr 2001:db8::/32 ip6 daddr 
2607:f8b0:4005::200e/128 accept } }` @@ -765,4 +799,159 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) assert.Contains(t, anonNftables, "table inet filter {") assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") + + // IPv6 public addresses in nftables should be anonymized + assert.NotContains(t, anonNftables, "2607:f8b0:4005::1") + assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e") + assert.NotContains(t, anonNftables, "2001:db8::") + assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range + + // ULA addresses in nftables should be anonymized (global ID is a fingerprint) + assert.NotContains(t, anonNftables, "fd00:1234::1") + + // IPv6 nftables structure preserved + assert.Contains(t, anonNftables, "ip6 saddr") + assert.Contains(t, anonNftables, "ip6 daddr") + + // Test ip6tables-save anonymization + anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave) + + // ULA IPv6 should be anonymized (global ID is a fingerprint) + assert.NotContains(t, anonIp6tablesSave, "fd00:1234::1/128") + + // Public IPv6 addresses should be anonymized + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1") + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e") + assert.NotContains(t, anonIp6tablesSave, "2001:db8::") + assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range + + // Structure should be preserved + assert.Contains(t, anonIp6tablesSave, "*filter") + assert.Contains(t, anonIp6tablesSave, "COMMIT") + assert.Contains(t, anonIp6tablesSave, "-j DROP") + assert.Contains(t, anonIp6tablesSave, "-j ACCEPT") +} + +// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in +// profilemanager.Config is either rendered in the debug bundle or explicitly +// excluded. 
When a new field is added to Config, this test fails until the +// developer either dumps it in addConfig/addCommonConfigFields or adds it to +// the excluded set with a justification. +func TestAddConfig_AllFieldsCovered(t *testing.T) { + excluded := map[string]string{ + "PrivateKey": "sensitive: WireGuard private key", + "PreSharedKey": "sensitive: WireGuard pre-shared key", + "SSHKey": "sensitive: SSH private key", + "ClientCertKeyPair": "non-config: parsed cert pair, not serialized", + } + + mURL, _ := url.Parse("https://api.example.com:443") + aURL, _ := url.Parse("https://admin.example.com:443") + bTrue := true + iVal := 42 + cfg := &profilemanager.Config{ + PrivateKey: "priv", + PreSharedKey: "psk", + ManagementURL: mURL, + AdminURL: aURL, + WgIface: "wt0", + WgPort: 51820, + NetworkMonitor: &bTrue, + IFaceBlackList: []string{"eth0"}, + DisableIPv6Discovery: true, + RosenpassEnabled: true, + RosenpassPermissive: true, + ServerSSHAllowed: &bTrue, + EnableSSHRoot: &bTrue, + EnableSSHSFTP: &bTrue, + EnableSSHLocalPortForwarding: &bTrue, + EnableSSHRemotePortForwarding: &bTrue, + DisableSSHAuth: &bTrue, + SSHJWTCacheTTL: &iVal, + DisableClientRoutes: true, + DisableServerRoutes: true, + DisableDNS: true, + DisableFirewall: true, + BlockLANAccess: true, + BlockInbound: true, + DisableNotifications: &bTrue, + DNSLabels: domain.List{}, + SSHKey: "sshkey", + NATExternalIPs: []string{"1.2.3.4"}, + CustomDNSAddress: "1.1.1.1:53", + DisableAutoConnect: true, + DNSRouteInterval: 5 * time.Second, + ClientCertPath: "/tmp/cert", + ClientCertKeyPath: "/tmp/key", + LazyConnectionEnabled: true, + MTU: 1280, + } + + for _, anonymize := range []bool{false, true} { + t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) { + g := &BundleGenerator{ + anonymizer: newAnonymizerForTest(), + internalConfig: cfg, + anonymize: anonymize, + } + + var sb strings.Builder + g.addCommonConfigFields(&sb) + rendered := sb.String() + 
renderAddConfigSpecific(g) + + val := reflect.ValueOf(cfg).Elem() + typ := val.Type() + var missing []string + for i := 0; i < typ.NumField(); i++ { + name := typ.Field(i).Name + if _, ok := excluded[name]; ok { + continue + } + if !strings.Contains(rendered, name+":") { + missing = append(missing, name) + } + } + if len(missing) > 0 { + t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+ + "Either render the field in addCommonConfigFields/addConfig, "+ + "or add it to the excluded map with a justification.", missing) + } + }) + } +} + +// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize +// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress). +// addCommonConfigFields covers the rest. Keeping this in the test mirrors the +// production shape without needing to write an actual zip. +func renderAddConfigSpecific(g *BundleGenerator) string { + var sb strings.Builder + if g.anonymize { + if g.internalConfig.ManagementURL != nil { + sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n") + } + if g.internalConfig.AdminURL != nil { + sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n") + } + sb.WriteString("NATExternalIPs: x\n") + if g.internalConfig.CustomDNSAddress != "" { + sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n") + } + } else { + if g.internalConfig.ManagementURL != nil { + sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n") + } + if g.internalConfig.AdminURL != nil { + sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n") + } + sb.WriteString("NATExternalIPs: x\n") + if g.internalConfig.CustomDNSAddress != "" { + sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n") + } + } + return sb.String() +} + +func newAnonymizerForTest() 
*anonymize.Anonymizer { + return anonymize.NewAnonymizer(anonymize.DefaultAddresses()) } diff --git a/client/internal/debug/upload_test.go b/client/internal/debug/upload_test.go index e833c196d..f224b8d3f 100644 --- a/client/internal/debug/upload_test.go +++ b/client/internal/debug/upload_test.go @@ -3,10 +3,12 @@ package debug import ( "context" "errors" + "net" "net/http" "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/require" @@ -19,8 +21,10 @@ func TestUpload(t *testing.T) { t.Skip("Skipping upload test on docker ci") } testDir := t.TempDir() - testURL := "http://localhost:8080" + addr := reserveLoopbackPort(t) + testURL := "http://" + addr t.Setenv("SERVER_URL", testURL) + t.Setenv("SERVER_ADDRESS", addr) t.Setenv("STORE_DIR", testDir) srv := server.NewServer() go func() { @@ -33,6 +37,7 @@ func TestUpload(t *testing.T) { t.Errorf("Failed to stop server: %v", err) } }) + waitForServer(t, addr) file := filepath.Join(t.TempDir(), "tmpfile") fileContent := []byte("test file content") @@ -47,3 +52,30 @@ func TestUpload(t *testing.T) { require.NoError(t, err) require.Equal(t, fileContent, createdFileContent) } + +// reserveLoopbackPort binds an ephemeral port on loopback to learn a free +// address, then releases it so the server under test can rebind. The close/ +// rebind window is racy in theory; on loopback with a kernel-assigned port +// it's essentially never contended in practice. 
+func reserveLoopbackPort(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + return addr +} + +func waitForServer(t *testing.T, addr string) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err == nil { + _ = c.Close() + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("server did not start listening on %s in time", addr) +} diff --git a/client/internal/dns.go b/client/internal/dns.go index f5040ee49..a6604810f 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -12,52 +12,83 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { - ip, err := netip.ParseAddr(aRecord.RData) +func createPTRRecord(record nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { + ip, err := netip.ParseAddr(record.RData) if err != nil { - log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err) + log.Warnf("failed to parse IP address %s: %v", record.RData, err) return nbdns.SimpleRecord{}, false } + ip = ip.Unmap() if !prefix.Contains(ip) { return nbdns.SimpleRecord{}, false } - ipOctets := strings.Split(ip.String(), ".") - slices.Reverse(ipOctets) - rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa") + var rdnsName string + if ip.Is4() { + octets := strings.Split(ip.String(), ".") + slices.Reverse(octets) + rdnsName = dns.Fqdn(strings.Join(octets, ".") + ".in-addr.arpa") + } else { + // Expand to full 32 nibbles in reverse order (LSB first) per RFC 3596. 
+ raw := ip.As16() + nibbles := make([]string, 32) + for i := 0; i < 16; i++ { + nibbles[31-i*2] = fmt.Sprintf("%x", raw[i]>>4) + nibbles[31-i*2-1] = fmt.Sprintf("%x", raw[i]&0x0f) + } + rdnsName = dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa") + } return nbdns.SimpleRecord{ Name: rdnsName, Type: int(dns.TypePTR), - Class: aRecord.Class, - TTL: aRecord.TTL, - RData: dns.Fqdn(aRecord.Name), + Class: record.Class, + TTL: record.TTL, + RData: dns.Fqdn(record.Name), }, true } -// generateReverseZoneName creates the reverse DNS zone name for a given network +// generateReverseZoneName creates the reverse DNS zone name for a given network. +// For IPv4 it produces an in-addr.arpa name, for IPv6 an ip6.arpa name. func generateReverseZoneName(network netip.Prefix) (string, error) { - networkIP := network.Masked().Addr() + networkIP := network.Masked().Addr().Unmap() + bits := network.Bits() - if !networkIP.Is4() { - return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP) + if networkIP.Is4() { + // Round up to nearest byte. + octetsToUse := (bits + 7) / 8 + + octets := strings.Split(networkIP.String(), ".") + if octetsToUse > len(octets) { + return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", bits) + } + + reverseOctets := make([]string, octetsToUse) + for i := 0; i < octetsToUse; i++ { + reverseOctets[octetsToUse-1-i] = octets[i] + } + + return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil } - // round up to nearest byte - octetsToUse := (network.Bits() + 7) / 8 + // IPv6: round up to nearest nibble (4-bit boundary). 
+ nibblesToUse := (bits + 3) / 4 - octets := strings.Split(networkIP.String(), ".") - if octetsToUse > len(octets) { - return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits()) + raw := networkIP.As16() + allNibbles := make([]string, 32) + for i := 0; i < 16; i++ { + allNibbles[i*2] = fmt.Sprintf("%x", raw[i]>>4) + allNibbles[i*2+1] = fmt.Sprintf("%x", raw[i]&0x0f) } - reverseOctets := make([]string, octetsToUse) - for i := 0; i < octetsToUse; i++ { - reverseOctets[octetsToUse-1-i] = octets[i] + // Take the first nibblesToUse nibbles (network portion), reverse them. + used := make([]string, nibblesToUse) + for i := 0; i < nibblesToUse; i++ { + used[nibblesToUse-1-i] = allNibbles[i] } - return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil + return dns.Fqdn(strings.Join(used, ".") + ".ip6.arpa"), nil } // zoneExists checks if a zone with the given name already exists in the configuration @@ -71,7 +102,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool { return false } -// collectPTRRecords gathers all PTR records for the given network from A records +// collectPTRRecords gathers all PTR records for the given network from A and AAAA records. 
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord { var records []nbdns.SimpleRecord @@ -80,7 +111,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple continue } for _, record := range zone.Records { - if record.Type != int(dns.TypeA) { + if record.Type != int(dns.TypeA) && record.Type != int(dns.TypeAAAA) { continue } diff --git a/client/internal/dns/file_parser_unix.go b/client/internal/dns/file_parser_unix.go index 8dacb4e51..50ba74c0c 100644 --- a/client/internal/dns/file_parser_unix.go +++ b/client/internal/dns/file_parser_unix.go @@ -13,6 +13,7 @@ import ( const ( defaultResolvConfPath = "/etc/resolv.conf" + nsswitchConfPath = "/etc/nsswitch.conf" ) type resolvConf struct { diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 6fbdedc59..57e7722d4 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,7 +1,10 @@ package dns import ( + "context" "fmt" + "math" + "net" "slices" "strconv" "strings" @@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() { } func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + c.dispatch(w, r, math.MaxInt) +} + +// dispatch routes a DNS request through the chain, skipping handlers with +// priority > maxPriority. Shared by ServeDNS and ResolveInternal. 
+func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) { if len(r.Question) == 0 { return } @@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // Try handlers in priority order for _, entry := range handlers { + if entry.Priority > maxPriority { + continue + } if !c.isHandlerMatch(qname, entry) { continue } @@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q cw.response.Len(), meta, time.Since(startTime)) } +// ResolveInternal runs an in-process DNS query against the chain, skipping any +// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt +// cache refresher) that must bypass themselves to avoid loops. Honors ctx +// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own +// (bounded by the invoked handler's internal timeout). +func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) { + if len(r.Question) == 0 { + return nil, fmt.Errorf("empty question") + } + + base := &internalResponseWriter{} + done := make(chan struct{}) + go func() { + c.dispatch(base, r, maxPriority) + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + // Prefer a completed response if dispatch finished concurrently with cancellation. + select { + case <-done: + default: + return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) + } + } + + if base.response == nil || base.response.Rcode == dns.RcodeRefused { + return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d", + strings.ToLower(r.Question[0].Name), maxPriority) + } + return base.response, nil +} + +// HasRootHandlerAtOrBelow reports whether any "." handler is registered at +// priority ≤ maxPriority. +func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, h := range c.handlers { + if h.Pattern == "." 
&& h.Priority <= maxPriority { + return true + } + } + return false +} + func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { switch { case entry.Pattern == ".": @@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { } } } + +// internalResponseWriter captures a dns.Msg for in-process chain queries. +type internalResponseWriter struct { + response *dns.Msg +} + +func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil } +func (w *internalResponseWriter) LocalAddr() net.Addr { return nil } +func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil } + +// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg +// still surface their answer to ResolveInternal. +func (w *internalResponseWriter) Write(p []byte) (int, error) { + msg := new(dns.Msg) + if err := msg.Unpack(p); err != nil { + return 0, err + } + w.response = msg + return len(p), nil +} + +func (w *internalResponseWriter) Close() error { return nil } +func (w *internalResponseWriter) TsigStatus() error { return nil } + +// TsigTimersOnly is part of dns.ResponseWriter. +func (w *internalResponseWriter) TsigTimersOnly(bool) { + // no-op: in-process queries carry no TSIG state. +} + +// Hijack is part of dns.ResponseWriter. +func (w *internalResponseWriter) Hijack() { + // no-op: in-process queries have no underlying connection to hand off. 
+} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index fa9525069..034a760dc 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,11 +1,15 @@ package dns_test import ( + "context" + "net" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/test" @@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { }) } } + +// answeringHandler writes a fixed A record to ack the query. Used to verify +// which handler ResolveInternal dispatches to. +type answeringHandler struct { + name string + ip string +} + +func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + _ = w.WriteMsg(resp) +} + +func (h *answeringHandler) String() string { return h.name } + +func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) { + chain := nbdns.NewHandlerChain() + + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + low := &answeringHandler{name: "low", ip: "10.0.0.2"} + + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + chain.AddHandler("example.com.", low, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 1, len(resp.Answer)) + a, ok := resp.Answer[0].(*dns.A) + assert.True(t, ok) + assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream") +} + +func 
TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) { + chain := nbdns.NewHandlerChain() + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + _, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.Error(t, err, "no handler at or below maxPriority should error") +} + +// rawWriteHandler packs a response and calls ResponseWriter.Write directly +// (instead of WriteMsg), exercising the internalResponseWriter.Write path. +type rawWriteHandler struct { + ip string +} + +func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + packed, err := resp.Pack() + if err != nil { + return + } + _, _ = w.Write(packed) +} + +func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) { + chain := nbdns.NewHandlerChain() + chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Answer, 1) + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer") +} + +func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) { + chain := nbdns.NewHandlerChain() + _, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream) + assert.Error(t, err) +} + +// hangingHandler blocks indefinitely until closed, simulating a wedged upstream. 
+type hangingHandler struct { + block chan struct{} +} + +func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + <-h.block + resp := &dns.Msg{} + resp.SetReply(r) + _ = w.WriteMsg(resp) +} + +func (h *hangingHandler) String() string { return "hangingHandler" } + +func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &hangingHandler{block: make(chan struct{})} + defer close(h.block) + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline") +} + +func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &answeringHandler{name: "h", ip: "10.0.0.1"} + + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain") + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count") + + chain.AddHandler(".", h, nbdns.PriorityMgmtCache) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded") + + chain.AddHandler(".", h, nbdns.PriorityDefault) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included") + + chain.RemoveHandler(".", nbdns.PriorityDefault) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) + + // Primary nsgroup case: root handler lands at PriorityUpstream. 
+ chain.AddHandler(".", h, nbdns.PriorityUpstream) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included") + chain.RemoveHandler(".", nbdns.PriorityUpstream) + + // Fallback case: original /etc/resolv.conf entries land at PriorityFallback. + chain.AddHandler(".", h, nbdns.PriorityFallback) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included") + chain.RemoveHandler(".", nbdns.PriorityFallback) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) +} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b3908f163..0f4eb6bf8 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -298,6 +298,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() { ip = ip.Unmap() serverAddresses = append(serverAddresses, ip) + // Prefer the first IPv4 server as ServerIP since our DNS listener is IPv4. 
if !dnsSettings.ServerIP.IsValid() && ip.Is4() { dnsSettings.ServerIP = ip } diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 422fed4e5..d7301d725 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -46,12 +46,12 @@ type restoreHostManager interface { } func newHostManager(wgInterface string) (hostManager, error) { - osManager, err := getOSDNSManagerType() + osManager, reason, err := getOSDNSManagerType() if err != nil { return nil, fmt.Errorf("get os dns manager type: %w", err) } - log.Infof("System DNS manager discovered: %s", osManager) + log.Infof("System DNS manager discovered: %s (%s)", osManager, reason) mgr, err := newHostManagerFromType(wgInterface, osManager) // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value if err != nil { @@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor } } -func getOSDNSManagerType() (osManagerType, error) { +func getOSDNSManagerType() (osManagerType, string, error) { + resolved := isSystemdResolvedRunning() + nss := isLibnssResolveUsed() + stub := checkStub() + + // Prefer systemd-resolved whenever it owns libc resolution, regardless of + // who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups + // that go through nss-resolve, and in foreign mode they can loop back + // through resolved as an upstream. 
+ if resolved && (nss || stub) { + return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil + } + + mgr, reason, rejected, err := scanResolvConfHeader() + if err != nil { + return 0, "", err + } + if reason != "" { + return mgr, reason, nil + } + + fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub) + if len(rejected) > 0 { + fallback += "; rejected: " + strings.Join(rejected, ", ") + } + return fileManager, fallback, nil +} + +// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the +// matching manager. If reason is empty the caller should pick file mode and +// use rejected for diagnostics. +func scanResolvConfHeader() (osManagerType, string, []string, error) { file, err := os.Open(defaultResolvConfPath) if err != nil { - return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) + return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) } defer func() { - if err := file.Close(); err != nil { - log.Errorf("close file %s: %s", defaultResolvConfPath, err) + if cerr := file.Close(); cerr != nil { + log.Errorf("close file %s: %s", defaultResolvConfPath, cerr) } }() + var rejected []string scanner := bufio.NewScanner(file) for scanner.Scan() { text := scanner.Text() @@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) { continue } if text[0] != '#' { - return fileManager, nil + break } - if strings.Contains(text, fileGeneratedResolvConfContentHeader) { - return netbirdManager, nil - } - if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { - return networkManager, nil - } - if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() { - if checkStub() { - return systemdManager, nil - } else { - return fileManager, nil 
- } - } - if strings.Contains(text, "resolvconf") { - if isSystemdResolveConfMode() { - return systemdManager, nil - } - - return resolvConfManager, nil + if mgr, reason, rej := matchResolvConfHeader(text); reason != "" { + return mgr, reason, nil, nil + } else if rej != "" { + rejected = append(rejected, rej) } } if err := scanner.Err(); err != nil && err != io.EOF { - return 0, fmt.Errorf("scan: %w", err) + return 0, "", nil, fmt.Errorf("scan: %w", err) } - - return fileManager, nil + return 0, "", rejected, nil } -// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager. +// matchResolvConfHeader inspects a single comment line. Returns either a +// definitive (manager, reason) or a non-empty rejected diagnostic. +func matchResolvConfHeader(text string) (osManagerType, string, string) { + if strings.Contains(text, fileGeneratedResolvConfContentHeader) { + return netbirdManager, "netbird-managed resolv.conf header detected", "" + } + if strings.Contains(text, "NetworkManager") { + if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { + return networkManager, "NetworkManager header + supported version on dbus", "" + } + return 0, "", "NetworkManager header (no dbus or unsupported version)" + } + if strings.Contains(text, "resolvconf") { + if isSystemdResolveConfMode() { + return systemdManager, "resolvconf header in systemd-resolved compatibility mode", "" + } + return resolvConfManager, "resolvconf header detected", "" + } + return 0, "", "" +} + +// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed +// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping +// into file mode while resolved is active. 
func checkStub() bool { rConf, err := parseDefaultResolvConf() if err != nil { - log.Warnf("failed to parse resolv conf: %s", err) + log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err) return true } @@ -139,3 +178,36 @@ func checkStub() bool { return false } + +// isLibnssResolveUsed reports whether nss-resolve is listed before dns on +// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are +// delegated to systemd-resolved regardless of /etc/resolv.conf. +func isLibnssResolveUsed() bool { + bs, err := os.ReadFile(nsswitchConfPath) + if err != nil { + log.Debugf("read %s: %v", nsswitchConfPath, err) + return false + } + return parseNsswitchResolveAhead(bs) +} + +func parseNsswitchResolveAhead(data []byte) bool { + for _, line := range strings.Split(string(data), "\n") { + if i := strings.IndexByte(line, '#'); i >= 0 { + line = line[:i] + } + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "hosts:" { + continue + } + for _, module := range fields[1:] { + switch module { + case "dns": + return false + case "resolve": + return true + } + } + } + return false +} diff --git a/client/internal/dns/host_unix_test.go b/client/internal/dns/host_unix_test.go new file mode 100644 index 000000000..e936281d3 --- /dev/null +++ b/client/internal/dns/host_unix_test.go @@ -0,0 +1,76 @@ +//go:build (linux && !android) || freebsd + +package dns + +import "testing" + +func TestParseNsswitchResolveAhead(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + { + name: "resolve before dns with action token", + in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n", + want: true, + }, + { + name: "dns before resolve", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n", + want: false, + }, + { + name: "debian default with only dns", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n", + want: false, + }, + { + name: "neither resolve nor dns", + in: 
"hosts: files myhostname\n", + want: false, + }, + { + name: "no hosts line", + in: "passwd: files systemd\ngroup: files systemd\n", + want: false, + }, + { + name: "empty", + in: "", + want: false, + }, + { + name: "comments and blank lines ignored", + in: "# comment\n\n# another\nhosts: resolve dns\n", + want: true, + }, + { + name: "trailing inline comment", + in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n", + want: true, + }, + { + name: "hosts token must be the first field", + in: " hosts: resolve dns\n", + want: true, + }, + { + name: "other db line mentioning resolve is ignored", + in: "networks: resolve\nhosts: dns\n", + want: false, + }, + { + name: "only resolve, no dns", + in: "hosts: files resolve\n", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want { + t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index a67a23945..e9d310f00 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -13,7 +13,6 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" @@ -67,9 +66,9 @@ func (d *Resolver) Stop() { d.mu.Lock() defer d.mu.Unlock() - maps.Clear(d.records) - maps.Clear(d.domains) - maps.Clear(d.zones) + clear(d.records) + clear(d.domains) + clear(d.zones) } // ID returns the unique handler ID @@ -444,9 +443,9 @@ func (d *Resolver) Update(customZones []nbdns.CustomZone) { d.mu.Lock() defer d.mu.Unlock() - maps.Clear(d.records) - maps.Clear(d.domains) - maps.Clear(d.zones) + clear(d.records) + clear(d.domains) + clear(d.zones) for _, zone := range customZones { zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain))) diff --git 
a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 314af51d9..988e427fb 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -2,40 +2,83 @@ package mgmt import ( "context" + "errors" "fmt" "net" - "net/netip" "net/url" + "os" + "slices" "strings" "sync" + "sync/atomic" "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/shared/management/domain" ) -const dnsTimeout = 5 * time.Second +const ( + dnsTimeout = 5 * time.Second + defaultTTL = 300 * time.Second + refreshBackoff = 30 * time.Second -// Resolver caches critical NetBird infrastructure domains + // envMgmtCacheTTL overrides defaultTTL for integration/dev testing. + envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" +) + +// ChainResolver lets the cache refresh stale entries through the DNS handler +// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the +// system resolver. +type ChainResolver interface { + ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) + HasRootHandlerAtOrBelow(maxPriority int) bool +} + +// cachedRecord holds DNS records plus timestamps used for TTL refresh. +// records and cachedAt are set at construction and treated as immutable; +// lastFailedRefresh and consecFailures are mutable and must be accessed under +// Resolver.mutex. +type cachedRecord struct { + records []dns.RR + cachedAt time.Time + lastFailedRefresh time.Time + consecFailures int +} + +// Resolver caches critical NetBird infrastructure domains. +// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex. 
type Resolver struct { - records map[dns.Question][]dns.RR + records map[dns.Question]*cachedRecord mgmtDomain *domain.Domain serverDomains *dnsconfig.ServerDomains mutex sync.RWMutex -} -type ipsResponse struct { - ips []netip.Addr - err error + chain ChainResolver + chainMaxPriority int + refreshGroup singleflight.Group + + // refreshing tracks questions whose refresh is running via the OS + // fallback path. A ServeDNS hit for a question in this map indicates + // the OS resolver routed the recursive query back to us (loop). Only + // the OS path arms this so chain-path refreshes don't produce false + // positives. The atomic bool is CAS-flipped once per refresh to + // throttle the warning log. + refreshing map[dns.Question]*atomic.Bool + + cacheTTL time.Duration } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - records: make(map[dns.Question][]dns.RR), + records: make(map[dns.Question]*cachedRecord), + refreshing: make(map[dns.Question]*atomic.Bool), + cacheTTL: resolveCacheTTL(), } } @@ -44,7 +87,19 @@ func (m *Resolver) String() string { return "MgmtCacheResolver" } -// ServeDNS implements dns.Handler interface. +// SetChainResolver wires the handler chain used to refresh stale cache entries. +// maxPriority caps which handlers may answer refresh queries (typically +// PriorityUpstream, so upstream/default/fallback handlers are consulted and +// mgmt/route/local handlers are skipped). +func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { + m.mutex.Lock() + m.chain = chain + m.chainMaxPriority = maxPriority + m.mutex.Unlock() +} + +// ServeDNS serves cached A/AAAA records. Stale entries are returned +// immediately and refreshed asynchronously (stale-while-revalidate). 
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - records, found := m.records[question] + cached, found := m.records[question] + inflight := m.refreshing[question] + var shouldRefresh bool + if found { + stale := time.Since(cached.cachedAt) > m.cacheTTL + inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff + shouldRefresh = stale && !inBackoff + } m.mutex.RUnlock() if !found { @@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + if inflight != nil && inflight.CompareAndSwap(false, true) { + log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)", + question.Name) + } + + // Skip scheduling a refresh goroutine if one is already inflight for + // this question; singleflight would dedup anyway but skipping avoids + // a parked goroutine per stale hit under bursty load. + if shouldRefresh && inflight == nil { + m.scheduleRefresh(question, cached) + } + resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - - resp.Answer = append(resp.Answer, records...) + resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt)) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { } } -// AddDomain manually adds a domain to cache by resolving it. +// AddDomain resolves a domain and stores its A/AAAA records in the cache. +// A family that resolves NODATA (nil err, zero records) evicts any stale +// entry for that qtype. 
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := lookupIPWithExtraTimeout(ctx, d) - if err != nil { - return err + aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) + + if errA != nil && errAAAA != nil { + return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) } - var aRecords, aaaaRecords []dns.RR - for _, ip := range ips { - if ip.Is4() { - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: ip.AsSlice(), - } - aRecords = append(aRecords, rr) - } else if ip.Is6() { - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: ip.AsSlice(), - } - aaaaRecords = append(aaaaRecords, rr) + if len(aRecords) == 0 && len(aaaaRecords) == 0 { + if err := errors.Join(errA, errAAAA); err != nil { + return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) } + return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } + now := time.Now() m.mutex.Lock() + defer m.mutex.Unlock() - if len(aRecords) > 0 { - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - m.records[aQuestion] = aRecords - } + m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) + m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) - if len(aaaaRecords) > 0 { - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - m.records[aaaaQuestion] = aaaaRecords - } - - m.mutex.Unlock() - - log.Debugf("added domain=%s with %d A records and %d AAAA records", + log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", d.SafeString(), len(aRecords), len(aaaaRecords)) return nil } -func lookupIPWithExtraTimeout(ctx 
context.Context, d domain.Domain) ([]netip.Addr, error) { - log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) - defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) - resultChan := make(chan *ipsResponse, 1) +// applyFamilyRecords writes records, evicts on NODATA, leaves the cache +// untouched on error. Caller holds m.mutex. +func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { + q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} + switch { + case len(records) > 0: + m.records[q] = &cachedRecord{records: records, cachedAt: now} + case err == nil: + delete(m.records, q) + } +} - go func() { - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) - resultChan <- &ipsResponse{ - err: err, - ips: ips, +// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per +// unique in-flight key; bursty stale hits share its channel. expected is the +// cachedRecord pointer observed by the caller; the refresh only mutates the +// cache if that pointer is still the one stored, so a stale in-flight refresh +// can't clobber a newer entry written by AddDomain or a competing refresh. +func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) { + key := question.Name + "|" + dns.TypeToString[question.Qtype] + _ = m.refreshGroup.DoChan(key, func() (any, error) { + return nil, m.refreshQuestion(question, expected) + }) +} + +// refreshQuestion replaces the cached records on success, or marks the entry +// failed (arming the backoff) on failure. While this runs, ServeDNS can detect +// a resolver loop by spotting a query for this same question arriving on us. +// expected pins the cache entry observed at schedule time; mutations only apply +// if m.records[question] still points at it. 
+func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error { + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) + if err != nil { + m.markRefreshFailed(question, expected) + return fmt.Errorf("parse domain: %w", err) + } + + records, err := m.lookupRecords(ctx, d, question) + if err != nil { + fails := m.markRefreshFailed(question, expected) + logf := log.Warnf + if fails == 0 || fails > 1 { + logf = log.Debugf } - }() - - var resp *ipsResponse - - select { - case <-time.After(dnsTimeout + time.Millisecond*500): - log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) - return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) - case <-ctx.Done(): - return nil, ctx.Err() - case resp = <-resultChan: + logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", + d.SafeString(), dns.TypeToString[question.Qtype], err, fails) + return err } - if resp.err != nil { - return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. 
+ if len(records) == 0 { + m.mutex.Lock() + if m.records[question] == expected { + delete(m.records, question) + m.mutex.Unlock() + log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.mutex.Unlock() + log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil } - return resp.ips, nil + + now := time.Now() + m.mutex.Lock() + if m.records[question] != expected { + m.mutex.Unlock() + log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.records[question] = &cachedRecord{records: records, cachedAt: now} + m.mutex.Unlock() + + log.Infof("refreshed mgmt cache domain=%s type=%s", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil +} + +func (m *Resolver) markRefreshing(question dns.Question) { + m.mutex.Lock() + m.refreshing[question] = &atomic.Bool{} + m.mutex.Unlock() +} + +func (m *Resolver) clearRefreshing(question dns.Question) { + m.mutex.Lock() + delete(m.refreshing, question) + m.mutex.Unlock() +} + +// markRefreshFailed arms the backoff and returns the new consecutive-failure +// count so callers can downgrade subsequent failure logs to debug. +func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int { + m.mutex.Lock() + defer m.mutex.Unlock() + c, ok := m.records[question] + if !ok || c != expected { + return 0 + } + c.lastFailedRefresh = time.Now() + c.consecFailures++ + return c.consecFailures +} + +// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let +// callers tell records, NODATA (nil err, no records), and failure apart. 
+func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) + return + } + + // TODO: drop once every supported OS registers a fallback resolver. Safe + // today: no root handler at priority ≤ PriorityUpstream means NetBird is + // not the system resolver, so net.DefaultResolver will not loop back. + aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) + return +} + +// lookupRecords resolves a single record type via chain or OS. The OS branch +// arms the loop detector for the duration of its call so that ServeDNS can +// spot the OS resolver routing the recursive query back to us. +func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) + } + + // TODO: drop once every supported OS registers a fallback resolver. + m.markRefreshing(q) + defer m.clearRefreshing(q) + + return m.osLookup(ctx, d, q.Name, q.Qtype) +} + +// lookupViaChain resolves via the handler chain and rewrites each RR to use +// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache +// target-owned records or upstream TTLs. NODATA returns (nil, nil). 
+func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { + msg := &dns.Msg{} + msg.SetQuestion(dnsName, qtype) + msg.RecursionDesired = true + + resp, err := chain.ResolveInternal(ctx, msg, maxPriority) + if err != nil { + return nil, fmt.Errorf("chain resolve: %w", err) + } + if resp == nil { + return nil, fmt.Errorf("chain resolve returned nil response") + } + if resp.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) + } + + ttl := uint32(m.cacheTTL.Seconds()) + owners := cnameOwners(dnsName, resp.Answer) + var filtered []dns.RR + for _, rr := range resp.Answer { + h := rr.Header() + if h.Class != dns.ClassINET || h.Rrtype != qtype { + continue + } + if !owners[strings.ToLower(dns.Fqdn(h.Name))] { + continue + } + if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { + filtered = append(filtered, cp) + } + } + return filtered, nil +} + +// osLookup resolves a single family via net.DefaultResolver using resutil, +// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA +// returns (nil, nil). 
+func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { + network := resutil.NetworkForQtype(qtype) + if network == "" { + return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) + } + + log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + + result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + if result.Rcode == dns.RcodeSuccess { + return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil + } + + if result.Err != nil { + return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) + } + return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) +} + +// responseTTL returns the remaining cache lifetime in seconds (rounded up), +// so downstream resolvers don't cache an answer for longer than we will. +func (m *Resolver) responseTTL(cachedAt time.Time) uint32 { + remaining := m.cacheTTL - time.Since(cachedAt) + if remaining <= 0 { + return 0 + } + return uint32((remaining + time.Second - 1) / time.Second) } // PopulateFromConfig extracts and caches domains from the client configuration. 
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - delete(m.records, aQuestion) - - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - delete(m.records, aaaaQuestion) + qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} + qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} + delete(m.records, qA) + delete(m.records, qAAAA) + delete(m.refreshing, qA) + delete(m.refreshing, qAAAA) log.Debugf("removed domain=%s from cache", d.SafeString()) return nil @@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } + +// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non +// A/AAAA records return nil. +func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + return &cp + case *dns.AAAA: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + return &cp + } + return nil +} + +// cloneRecordsWithTTL clones A/AAAA records preserving their owner and +// stamping ttl so the response shares no memory with the cached slice. +func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { + out := make([]dns.RR, 0, len(records)) + for _, rr := range records { + if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil { + out = append(out, cp) + } + } + return out +} + +// cnameOwners returns dnsName plus every target reachable by following CNAMEs +// in answer, iterating until fixed point so out-of-order chains resolve. 
+func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { + owners := map[string]bool{dnsName: true} + for { + added := false + for _, rr := range answer { + cname, ok := rr.(*dns.CNAME) + if !ok { + continue + } + name := strings.ToLower(dns.Fqdn(cname.Hdr.Name)) + if !owners[name] { + continue + } + target := strings.ToLower(dns.Fqdn(cname.Target)) + if !owners[target] { + owners[target] = true + added = true + } + } + if !added { + return owners + } + } +} + +// resolveCacheTTL reads the cache TTL override env var; invalid or empty +// values fall back to defaultTTL. Called once per Resolver from NewResolver. +func resolveCacheTTL() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return defaultTTL +} diff --git a/client/internal/dns/mgmt/mgmt_refresh_test.go b/client/internal/dns/mgmt/mgmt_refresh_test.go new file mode 100644 index 000000000..9faa5a0b8 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_refresh_test.go @@ -0,0 +1,408 @@ +package mgmt + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +type fakeChain struct { + mu sync.Mutex + calls map[string]int + answers map[string][]dns.RR + err error + hasRoot bool + onLookup func() +} + +func newFakeChain() *fakeChain { + return &fakeChain{ + calls: map[string]int{}, + answers: map[string][]dns.RR{}, + hasRoot: true, + } +} + +func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.hasRoot +} + +func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) { + f.mu.Lock() + q := msg.Question[0] + key := q.Name + "|" + 
dns.TypeToString[q.Qtype] + f.calls[key]++ + answers := f.answers[key] + err := f.err + onLookup := f.onLookup + f.mu.Unlock() + + if onLookup != nil { + onLookup() + } + if err != nil { + return nil, err + } + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Answer = answers + return resp, nil +} + +func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) { + f.mu.Lock() + defer f.mu.Unlock() + key := name + "|" + dns.TypeToString[qtype] + hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60} + switch qtype { + case dns.TypeA: + f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}} + case dns.TypeAAAA: + f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}} + } +} + +func (f *fakeChain) callCount(name string, qtype uint16) int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls[name+"|"+dns.TypeToString[qtype]] +} + +// waitFor polls the predicate until it returns true or the deadline passes. +func waitFor(t *testing.T, d time.Duration, fn func() bool) { + t.Helper() + deadline := time.Now().Add(d) + for time.Now().Before(deadline) { + if fn() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("condition not met within %s", d) +} + +func queryA(t *testing.T, r *Resolver, name string) *dns.Msg { + t.Helper() + msg := new(dns.Msg) + msg.SetQuestion(name, dns.TypeA) + w := &test.MockResponseWriter{} + r.ServeDNS(w, msg) + return w.GetLastResponse() +} + +func firstA(t *testing.T, resp *dns.Msg) string { + t.Helper() + require.NotNil(t, resp) + require.Greater(t, len(resp.Answer), 0, "expected at least one answer") + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok, "expected A record") + return a.A.String() +} + +func TestResolver_CacheTTLGatesRefresh(t *testing.T) { + // Same cached entry age, different cacheTTL values: the shorter TTL must + // trigger a background refresh, the longer one must not. 
Proves that the + // per-Resolver cacheTTL field actually drives the stale decision. + cachedAt := time.Now().Add(-100 * time.Millisecond) + + newRec := func() *cachedRecord { + return &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: cachedAt, + } + } + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + + t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = 10 * time.Millisecond + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount(q.Name, dns.TypeA) >= 1 + }) + }) + + t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = time.Hour + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh") + }) +} + +func TestResolver_ServeFresh_NoRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), // fresh + } + + resp := 
queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(20 * time.Millisecond) + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh") +} + +func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), // stale + } + + // First query: serves stale immediately. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1 + }) + + // Next query should now return the refreshed IP. 
+ waitFor(t, time.Second, func() bool { + resp := queryA(t, r, "mgmt.example.com.") + return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2" + }) +} + +func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + + var inflight atomic.Int32 + var maxInflight atomic.Int32 + chain.onLookup = func() { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + prev := maxInflight.Load() + if cur <= prev || maxInflight.CompareAndSwap(prev, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide + } + + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + queryA(t, r, "mgmt.example.com.") + }() + } + wg.Wait() + + waitFor(t, 2*time.Second, func() bool { + return inflight.Load() == 0 + }) + + calls := chain.callCount("mgmt.example.com.", dns.TypeA) + assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls) + assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently") +} + +func TestResolver_RefreshFailureArmsBackoff(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.err = errors.New("boom") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + 
cachedAt: time.Now().Add(-2 * defaultTTL), + } + + // First stale hit triggers a refresh attempt that fails. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) == 1 + }) + waitFor(t, time.Second, func() bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + c, ok := r.records[q] + return ok && !c.lastFailedRefresh.IsZero() + }) + + // Subsequent stale hits within backoff window should not schedule more refreshes. + for i := 0; i < 10; i++ { + queryA(t, r, "mgmt.example.com.") + } + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes") +} + +func TestResolver_NoRootHandler_SkipsChain(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.hasRoot = false + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + // With hasRoot=false the chain must not be consulted. Use a short + // deadline so the OS fallback returns quickly without waiting on a + // real network call in CI. + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + _, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.") + + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), + "chain must not be used when no root handler is registered at the bound priority") +} + +func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) { + // ServeDNS being invoked for a question while a refresh for that question + // is inflight indicates a resolver loop (OS resolver sent the recursive + // query back to us). The inflightRefresh.loopLoggedOnce flag must be set. 
+ r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + // Simulate an inflight refresh. + r.markRefreshing(q) + defer r.clearRefreshing(q) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries") + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + require.NotNil(t, inflight) + assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed") +} + +func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + r.markRefreshing(q) + defer r.clearRefreshing(q) + + // Multiple ServeDNS calls during the same refresh must not re-set the flag + // (CompareAndSwap from false -> true returns true only on the first call). 
+ for range 5 { + queryA(t, r, "mgmt.example.com.") + } + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + assert.True(t, inflight.Load()) +} + +func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + queryA(t, r, "mgmt.example.com.") + + r.mutex.RLock() + _, ok := r.refreshing[q] + r.mutex.RUnlock() + assert.False(t, ok, "no refresh inflight means no loop tracking") +} + +func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2") + r.SetChainResolver(chain, 50) + + require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com"))) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.2", firstA(t, resp)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA)) +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 9e8a746f3..276cbba0a 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -6,6 +6,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) { assert.False(t, resolver.MatchSubdomains()) } +func TestResolveCacheTTL(t *testing.T) { + tests := []struct { + name string + value string + want time.Duration + }{ + {"unset falls back to default", "", defaultTTL}, + {"valid duration", "45s", 45 * 
time.Second}, + {"valid minutes", "2m", 2 * time.Minute}, + {"malformed falls back to default", "not-a-duration", defaultTTL}, + {"zero falls back to default", "0s", defaultTTL}, + {"negative falls back to default", "-5s", defaultTTL}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envMgmtCacheTTL, tc.value) + got := resolveCacheTTL() + assert.Equal(t, tc.want, got, "parsed TTL should match") + }) + } +} + +func TestNewResolver_CacheTTLFromEnv(t *testing.T) { + t.Setenv(envMgmtCacheTTL, "7s") + r := NewResolver() + assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env") +} + +func TestResolver_ResponseTTL(t *testing.T) { + now := time.Now() + tests := []struct { + name string + cacheTTL time.Duration + cachedAt time.Time + wantMin uint32 + wantMax uint32 + }{ + {"fresh entry returns full TTL", 60 * time.Second, now, 59, 60}, + {"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31}, + {"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0}, + {"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := &Resolver{cacheTTL: tc.cacheTTL} + got := r.responseTTL(tc.cachedAt) + assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin") + assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax") + }) + } +} + func TestResolver_ExtractDomainFromURL(t *testing.T) { tests := []struct { name string diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index e4ccc8cbd..66d82dcd7 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -110,8 +110,25 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st connSettings.cleanDeprecatedSettings() - convDNSIP := 
binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + ipKey := networkManagerDbusIPv4Key + staleKey := networkManagerDbusIPv6Key + if config.ServerIP.Is6() { + ipKey = networkManagerDbusIPv6Key + staleKey = networkManagerDbusIPv4Key + raw := config.ServerIP.As16() + connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([][]byte{raw[:]}) + } else { + convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) + connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + } + + // Clear stale DNS settings from the opposite address family to avoid + // leftover entries if the server IP family changed. + if staleSettings, ok := connSettings[staleKey]; ok { + delete(staleSettings, networkManagerDbusDNSKey) + delete(staleSettings, networkManagerDbusDNSPriorityKey) + delete(staleSettings, networkManagerDbusDNSSearchKey) + } var ( searchDomains []string matchDomains []string @@ -146,8 +163,8 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st n.routingAll = false } - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) + connSettings[ipKey][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) + connSettings[ipKey][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) state := &ShutdownState{ ManagerType: networkManager, diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f7865047b..6fe2e21b6 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -212,6 +212,7 @@ func newDefaultServer( ctx, stop := context.WithCancel(ctx) mgmtCacheResolver := mgmt.NewResolver() + mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream) defaultServer := 
&DefaultServer{ ctx: ctx, @@ -409,7 +410,7 @@ func (s *DefaultServer) Stop() { log.Errorf("failed to disable DNS: %v", err) } - maps.Clear(s.extraDomains) + clear(s.extraDomains) } func (s *DefaultServer) disableDNS() (retErr error) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index f77f6e898..1026a29fc 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -347,7 +347,7 @@ func TestUpdateDNSServer(t *testing.T) { opts := iface.WGIFaceOpts{ IFaceName: fmt.Sprintf("utun230%d", n), - Address: fmt.Sprintf("100.66.100.%d/32", n+1), + Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)), WGPort: 33100, WGPrivKey: privKey.String(), MTU: iface.DefaultMTU, @@ -448,7 +448,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { privKey, _ := wgtypes.GeneratePrivateKey() opts := iface.WGIFaceOpts{ IFaceName: "utun2301", - Address: "100.66.100.1/32", + Address: wgaddr.MustParseWGAddress("100.66.100.1/32"), WGPort: 33100, WGPrivKey: privKey.String(), MTU: iface.DefaultMTU, @@ -929,7 +929,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { opts := iface.WGIFaceOpts{ IFaceName: "utun2301", - Address: "100.66.100.2/24", + Address: wgaddr.MustParseWGAddress("100.66.100.2/24"), WGPort: 33100, WGPrivKey: privKey.String(), MTU: iface.DefaultMTU, diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index 1c6ce7849..04bcd5985 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -16,8 +16,8 @@ const ( // This is used when the DNS server cannot bind port 53 directly // and needs firewall rules to redirect traffic. 
type Firewall interface { - AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error - RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error + AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error + RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error } type service interface { diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 4e09f1b7f..9c0e52af8 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -188,11 +188,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { return s.listenIP } - -// evalListenAddress figure out the listen address for the DNS server -// first check the 53 port availability on WG interface or lo, if not success -// pick a random port on WG interface for eBPF, if not success -// check the 5053 port availability on WG interface or lo without eBPF usage, +// evalListenAddress figures out the listen address for the DNS server. +// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4. +// First checks port 53 on WG interface or lo, then tries eBPF on a random port, +// then falls back to port 5053. 
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { if s.customAddr != nil { return s.customAddr.Addr(), s.customAddr.Port(), nil @@ -278,7 +277,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) { } ebpfSrv := ebpf.GetEbpfManagerInstance() - err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port)) + err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port)) if err != nil { log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err) return nil, 0, false diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index d9854c033..573dff540 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -90,8 +90,12 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { } func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { + family := int32(unix.AF_INET) + if config.ServerIP.Is6() { + family = unix.AF_INET6 + } defaultLinkInput := systemdDbusDNSInput{ - Family: unix.AF_INET, + Family: family, Address: config.ServerIP.AsSlice(), } if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 746b73ca7..a26536f6e 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -21,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" @@ -29,6 +30,12 @@ import ( var currentMTU uint16 = iface.DefaultMTU +// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate. 
+type privateClientIface interface { + Name() string + Address() wgaddr.Address +} + func SetCurrentMTU(mtu uint16) { currentMTU = mtu } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ee1ca42fe..988adb7d2 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool { return false } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 1143b6c51..910c3779e 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns return ExchangeWithFallback(ctx, client, r, upstream) } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 02c11173b..0e04742a0 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -19,9 +19,7 @@ import ( type upstreamResolverIOS struct { *upstreamResolverBase - lIP netip.Addr - lNet netip.Prefix - interfaceName string + wgIface WGIface } func newUpstreamResolver( @@ -35,9 +33,7 @@ func newUpstreamResolver( ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, - lIP: wgIface.Address().IP, - lNet: wgIface.Address().Network, - interfaceName: wgIface.Name(), + wgIface: 
wgIface, } ios.upstreamClient = ios @@ -65,11 +61,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - needsPrivate := u.lNet.Contains(upstreamIP) || + addr := u.wgIface.Address() + needsPrivate := addr.Network.Contains(upstreamIP) || + addr.IPv6Net.Contains(upstreamIP) || (u.routeMatch != nil && u.routeMatch(upstreamIP)) if needsPrivate { log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) - client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) + client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout) if err != nil { return nil, 0, fmt.Errorf("create private client: %s", err) } @@ -79,25 +77,33 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * return ExchangeWithFallback(nil, client, r, upstream) } -// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface -// This method is needed for iOS -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { - index, err := getInterfaceIndex(interfaceName) +// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface. +// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4. 
+func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { + index, err := getInterfaceIndex(iface.Name()) if err != nil { - log.Debugf("unable to get interface index for %s: %s", interfaceName, err) + log.Debugf("unable to get interface index for %s: %s", iface.Name(), err) return nil, err } + addr := iface.Address() + bindIP := addr.IP + if upstreamIP.Is6() && addr.HasIPv6() { + bindIP = addr.IPv6 + } + + proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF + if bindIP.Is6() { + proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF + } + dialer := &net.Dialer{ - LocalAddr: &net.UDPAddr{ - IP: ip.AsSlice(), - Port: 0, // Let the OS pick a free port - }, - Timeout: dialTimeout, + LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)), + Timeout: dialTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { - operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) + operr = unix.SetsockoptInt(int(s), proto, opt, index) } if err := c.Control(fn); err != nil { diff --git a/client/internal/dns_test.go b/client/internal/dns_test.go new file mode 100644 index 000000000..e15cc8fb7 --- /dev/null +++ b/client/internal/dns_test.go @@ -0,0 +1,138 @@ +package internal + +import ( + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" +) + +func TestCreatePTRRecord_IPv4(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "100.64.0.5", + } + prefix := netip.MustParsePrefix("100.64.0.0/16") + + ptr, ok := createPTRRecord(record, prefix) + require.True(t, ok) + assert.Equal(t, "5.0.64.100.in-addr.arpa.", ptr.Name) + assert.Equal(t, int(dns.TypePTR), ptr.Type) + assert.Equal(t, "peer1.netbird.cloud.", ptr.RData) +} + 
+func TestCreatePTRRecord_IPv6(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "fd00:1234:5678::1", + } + prefix := netip.MustParsePrefix("fd00:1234:5678::/48") + + ptr, ok := createPTRRecord(record, prefix) + require.True(t, ok) + assert.Equal(t, "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", ptr.Name) + assert.Equal(t, int(dns.TypePTR), ptr.Type) + assert.Equal(t, "peer1.netbird.cloud.", ptr.RData) +} + +func TestCreatePTRRecord_OutOfRange(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeA), + RData: "10.0.0.1", + } + prefix := netip.MustParsePrefix("100.64.0.0/16") + + _, ok := createPTRRecord(record, prefix) + assert.False(t, ok) +} + +func TestGenerateReverseZoneName_IPv4(t *testing.T) { + tests := []struct { + prefix string + expected string + }{ + {"100.64.0.0/16", "64.100.in-addr.arpa."}, + {"10.0.0.0/8", "10.in-addr.arpa."}, + {"192.168.1.0/24", "1.168.192.in-addr.arpa."}, + } + + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix)) + require.NoError(t, err) + assert.Equal(t, tt.expected, zone) + }) + } +} + +func TestGenerateReverseZoneName_IPv6(t *testing.T) { + tests := []struct { + prefix string + expected string + }{ + {"fd00:1234:5678::/48", "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa."}, + {"fd00::/16", "0.0.d.f.ip6.arpa."}, + {"fd12:3456:789a:bcde::/64", "e.d.c.b.a.9.8.7.6.5.4.3.2.1.d.f.ip6.arpa."}, + } + + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix)) + require.NoError(t, err) + assert.Equal(t, tt.expected, zone) + }) + } +} + +func TestCollectPTRRecords_BothFamilies(t *testing.T) { + config := &nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud.", + Records: 
[]nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.1"}, + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00::1"}, + {Name: "peer2.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.2"}, + }, + }, + }, + } + + v4Records := collectPTRRecords(config, netip.MustParsePrefix("100.64.0.0/16")) + assert.Len(t, v4Records, 2, "should collect 2 A record PTRs for the v4 prefix") + + v6Records := collectPTRRecords(config, netip.MustParsePrefix("fd00::/64")) + assert.Len(t, v6Records, 1, "should collect 1 AAAA record PTR for the v6 prefix") +} + +func TestAddReverseZone_IPv6(t *testing.T) { + config := &nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00:1234:5678::1"}, + }, + }, + }, + } + + addReverseZone(config, netip.MustParsePrefix("fd00:1234:5678::/48")) + + require.Len(t, config.CustomZones, 2) + reverseZone := config.CustomZones[1] + assert.Equal(t, "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", reverseZone.Domain) + assert.Len(t, reverseZone.Records, 1) + assert.Equal(t, int(dns.TypePTR), reverseZone.Records[0].Type) +} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 58b88d9ef..c4c16cd3f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } + // IPv4-only: peers reach the forwarder via its v4 overlay address. 
localAddr := m.wgIface.Address().IP if localAddr.IsValid() && m.firewall != nil { diff --git a/client/internal/ebpf/ebpf/dns_fwd_linux.go b/client/internal/ebpf/ebpf/dns_fwd_linux.go index 93797da76..1e7774573 100644 --- a/client/internal/ebpf/ebpf/dns_fwd_linux.go +++ b/client/internal/ebpf/ebpf/dns_fwd_linux.go @@ -2,7 +2,8 @@ package ebpf import ( "encoding/binary" - "net" + "fmt" + "net/netip" log "github.com/sirupsen/logrus" ) @@ -12,7 +13,7 @@ const ( mapKeyDNSPort uint32 = 1 ) -func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { +func (tf *GeneralManager) LoadDNSFwd(ip netip.Addr, dnsPort int) error { log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort) tf.lock.Lock() defer tf.lock.Unlock() @@ -22,7 +23,11 @@ func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { return err } - err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip)) + if !ip.Is4() { + return fmt.Errorf("eBPF DNS forwarder only supports IPv4, got %s", ip) + } + ip4 := ip.As4() + err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, binary.BigEndian.Uint32(ip4[:])) if err != nil { return err } @@ -45,7 +50,3 @@ func (tf *GeneralManager) FreeDNSFwd() error { return tf.unsetFeatureFlag(featureFlagDnsForwarder) } -func ip2int(ipString string) uint32 { - ip := net.ParseIP(ipString) - return binary.BigEndian.Uint32(ip.To4()) -} diff --git a/client/internal/ebpf/manager/manager.go b/client/internal/ebpf/manager/manager.go index af10142d5..25a767090 100644 --- a/client/internal/ebpf/manager/manager.go +++ b/client/internal/ebpf/manager/manager.go @@ -1,8 +1,10 @@ package manager +import "net/netip" + // Manager is used to load multiple eBPF programs. 
E.g., current DNS programs and WireGuard proxy type Manager interface { - LoadDNSFwd(ip string, dnsPort int) error + LoadDNSFwd(ip netip.Addr, dnsPort int) error FreeDNSFwd() error LoadWgProxy(proxyPort, wgPort int) error FreeWGProxy() error diff --git a/client/internal/engine.go b/client/internal/engine.go index b49e02c6d..66fe6056b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -26,11 +26,14 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" @@ -62,11 +65,13 @@ import ( mgm "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" sProto "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/capture" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. 
@@ -85,8 +90,9 @@ type EngineConfig struct { WgPort int WgIfaceName string - // WgAddr is a Wireguard local address (Netbird Network IP) - WgAddr string + // WgAddr is the Wireguard local address (Netbird Network IP). + // Contains both v4 and optional v6 overlay addresses. + WgAddr wgaddr.Address // WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine) WgPrivateKey wgtypes.Key @@ -131,6 +137,7 @@ type EngineConfig struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool LazyConnectionEnabled bool @@ -217,6 +224,8 @@ type Engine struct { portForwardManager *portforward.Manager srWatcher *guard.SRWatcher + afpacketCapture *capture.AFPacketCapture + // Sync response persistence (protected by syncRespMux) syncRespMux sync.RWMutex persistSyncResponse bool @@ -570,7 +579,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.connMgr.Start(e.ctx) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) - e.srWatcher.Start() + e.srWatcher.Start(peer.IsForceRelayed()) e.receiveSignalEvents() e.receiveManagementEvents() @@ -604,6 +613,8 @@ func (e *Engine) createFirewall() error { return nil } + firewalld.SetParentContext(e.ctx) + var err error e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil { @@ -637,7 +648,7 @@ func (e *Engine) initFirewall() error { rosenpassPort := e.rpManager.GetAddress().Port port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}} - // this rule is static and will be torn down on engine down by the firewall manager + // IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4. 
if _, err := e.firewall.AddPeerFiltering( nil, net.IP{0, 0, 0, 0}, @@ -689,10 +700,15 @@ func (e *Engine) blockLanAccess() { log.Infof("blocking route LAN access for networks: %v", toBlock) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0) for _, network := range toBlock { + source := v4 + if network.Addr().Is6() { + source = v6 + } if _, err := e.firewall.AddRouteFiltering( nil, - []netip.Prefix{v4}, + []netip.Prefix{source}, firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, @@ -730,7 +746,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { if !ok { continue } - if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) { + if !compareNetIPLists(allowedIPs, e.filterAllowedIPs(p.GetAllowedIps())) { modified = append(modified, p) continue } @@ -941,7 +957,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error { return fmt.Errorf("update relay token: %w", err) } - e.relayManager.UpdateServerURLs(update.Urls) + urls := update.Urls + if override, ok := peer.OverrideRelayURLs(); ok { + log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override) + urls = override + } + e.relayManager.UpdateServerURLs(urls) // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. // We can ignore all errors because the guard will manage the reconnection retries. 
@@ -1004,6 +1025,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { e.config.DisableFirewall, e.config.BlockLANAccess, e.config.BlockInbound, + e.config.DisableIPv6, e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, @@ -1031,6 +1053,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { return ErrResetConnection } + if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) { + log.Infof("peer IPv6 address changed, restarting client") + _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) + e.clientCancel() + return ErrResetConnection + } + if conf.GetSshConfig() != nil { if err := e.updateSSH(conf.GetSshConfig()); err != nil { log.Warnf("failed handling SSH server setup: %v", err) @@ -1039,6 +1068,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { state := e.statusRecorder.GetLocalPeerState() state.IP = e.wgInterface.Address().String() + state.IPv6 = e.wgInterface.Address().IPv6String() state.PubKey = e.config.WgPrivateKey.PublicKey().String() state.KernelInterface = !e.wgInterface.IsUserspaceBind() state.FQDN = conf.GetFqdn() @@ -1047,6 +1077,28 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { return nil } + +// hasIPv6Changed reports whether the IPv6 overlay address in the peer config +// differs from the configured address (added, removed, or changed). +// Compares against e.config.WgAddr (not the interface address, which may have +// been cleared by ClearIPv6 if OS assignment failed). 
+func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool { + current := e.config.WgAddr + raw := conf.GetAddressV6() + + if len(raw) == 0 { + return current.HasIPv6() + } + + prefix, err := netiputil.DecodePrefix(raw) + if err != nil { + log.Errorf("decode v6 overlay address: %v", err) + return false + } + + return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked() +} + func (e *Engine) receiveJobEvents() { e.jobExecutorWG.Add(1) go func() { @@ -1145,6 +1197,7 @@ func (e *Engine) receiveManagementEvents() { e.config.DisableFirewall, e.config.BlockLANAccess, e.config.BlockInbound, + e.config.DisableIPv6, e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, @@ -1244,7 +1297,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address()) if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) @@ -1399,7 +1452,9 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE return entries } -func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { +func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, addr wgaddr.Address) nbdns.Config { + network := addr.Network + networkV6 := addr.IPv6Net //nolint forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) if forwarderPort == 0 { @@ -1456,6 +1511,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns if len(dnsUpdate.CustomZones) > 0 { addReverseZone(&dnsUpdate, network) + if networkV6.IsValid() { + addReverseZone(&dnsUpdate, networkV6) + } } return dnsUpdate @@ -1465,8 +1523,10 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { replacement := make([]peer.State, 
len(offlinePeers)) for i, offlinePeer := range offlinePeers { log.Debugf("added offline peer %s", offlinePeer.Fqdn) + v4, v6 := overlayAddrsFromAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net) replacement[i] = peer.State{ - IP: strings.Join(offlinePeer.GetAllowedIps(), ","), + IP: addrToString(v4), + IPv6: addrToString(v6), PubKey: offlinePeer.GetWgPubKey(), FQDN: offlinePeer.GetFqdn(), ConnStatus: peer.StatusIdle, @@ -1477,6 +1537,37 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { e.statusRecorder.ReplaceOfflinePeers(replacement) } +// overlayAddrsFromAllowedIPs extracts the peer's v4 and v6 overlay addresses +// from AllowedIPs strings. Only host routes (/32, /128) are considered; v6 must +// fall within ourV6Net to distinguish overlay addresses from routed prefixes. +func overlayAddrsFromAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 netip.Addr) { + for _, cidr := range allowedIPs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + log.Warnf("failed to parse AllowedIP %q: %v", cidr, err) + continue + } + addr := prefix.Addr().Unmap() + switch { + case addr.Is4() && prefix.Bits() == 32 && !v4.IsValid(): + v4 = addr + case addr.Is6() && prefix.Bits() == 128 && ourV6Net.Contains(addr) && !v6.IsValid(): + v6 = addr + } + if v4.IsValid() && v6.IsValid() { + break + } + } + return +} + +func addrToString(addr netip.Addr) string { + if !addr.IsValid() { + return "" + } + return addr.String() +} + // addNewPeers adds peers that were not know before but arrived from the Management service with the update func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { for _, p := range peersUpdate { @@ -1502,15 +1593,23 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { log.Errorf("failed to parse allowedIPS: %v", err) return err } + if allowedNetIP.Addr().Is6() && !e.wgInterface.Address().HasIPv6() { + continue + } peerIPs = append(peerIPs, 
allowedNetIP) } + if len(peerIPs) == 0 { + return fmt.Errorf("peer %s has no usable AllowedIPs", peerKey) + } + conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion) if err != nil { return fmt.Errorf("create peer connection: %w", err) } - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String()) + peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, addrToString(peerV4), addrToString(peerV6)) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) } @@ -1695,6 +1794,11 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { @@ -1740,6 +1844,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err e.config.DisableFirewall, e.config.BlockLANAccess, e.config.BlockInbound, + e.config.DisableIPv6, e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, @@ -1753,7 +1858,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err return nil, nil, false, err } routes := toRoutes(netMap.GetRoutes()) - dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network) + dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address()) dnsFeatureFlag := toDNSFeatureFlag(netMap) return routes, &dnsCfg, dnsFeatureFlag, nil } @@ -1795,7 +1900,10 @@ func (e *Engine) wgInterfaceCreate() (err error) { case "android": err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP().String(), e.dnsServer.SearchDomains()) case "ios": - e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr) + 
e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr.String()) + if e.config.WgAddr.HasIPv6() { + e.mobileDep.NetworkChangeListener.SetInterfaceIPv6(e.config.WgAddr.IPv6String()) + } err = e.wgInterface.Create() default: err = e.wgInterface.Create() @@ -2072,6 +2180,14 @@ func (e *Engine) GetWgAddr() netip.Addr { return e.wgInterface.Address().IP } +// GetWgV6Addr returns the IPv6 overlay address of the WireGuard interface. +func (e *Engine) GetWgV6Addr() netip.Addr { + if e.wgInterface == nil { + return netip.Addr{} + } + return e.wgInterface.Address().IPv6 +} + func (e *Engine) RenewTun(fd int) error { e.syncMsgMux.Lock() wgInterface := e.wgInterface @@ -2160,6 +2276,62 @@ func (e *Engine) Address() (netip.Addr, error) { return e.wgInterface.Address().IP, nil } +// SetCapture sets or clears packet capture on the WireGuard device. +// On userspace WireGuard, it taps the FilteredDevice directly. +// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture. +// Pass nil to disable capture. +func (e *Engine) SetCapture(pc device.PacketCapture) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + intf := e.wgInterface + if intf == nil { + return errors.New("wireguard interface not initialized") + } + + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + + dev := intf.GetDevice() + if dev != nil { + dev.SetCapture(pc) + e.setForwarderCapture(pc) + return nil + } + + // Kernel mode: no FilteredDevice. Use AF_PACKET on Linux. 
+ if pc == nil { + return nil + } + sess, ok := pc.(*capture.Session) + if !ok { + return errors.New("filtered device not available and AF_PACKET requires *capture.Session") + } + + afc := capture.NewAFPacketCapture(intf.Name(), sess) + if err := afc.Start(); err != nil { + return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err) + } + e.afpacketCapture = afc + return nil +} + +// setForwarderCapture propagates capture to the USP filter's forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. +func (e *Engine) setForwarderCapture(pc device.PacketCapture) { + if e.firewall == nil { + return + } + type forwarderCapturer interface { + SetPacketCapture(pc forwarder.PacketCapture) + } + if fc, ok := e.firewall.(forwarderCapturer); ok { + fc.SetPacketCapture(pc) + } +} + func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { if e.firewall == nil { log.Warn("firewall is disabled, not updating forwarding rules") @@ -2297,8 +2469,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked() ip := prefix.Addr() - // TODO: add IPv6 - if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { continue } @@ -2309,6 +2480,24 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { return prefixes, nberrors.FormatErrorOrNil(merr) } +// filterAllowedIPs strips IPv6 entries when the local interface has no v6 address. +// This covers both the explicit --disable-ipv6 flag (v6 never assigned) and the +// case where OS v6 assignment failed (ClearIPv6). Without this, WireGuard would +// accept v6 traffic that the native firewall cannot filter. 
+func (e *Engine) filterAllowedIPs(ips []string) []string { + if e.wgInterface.Address().HasIPv6() { + return ips + } + filtered := make([]string, 0, len(ips)) + for _, s := range ips { + p, err := netip.ParsePrefix(s) + if err != nil || !p.Addr().Is6() { + filtered = append(filtered, s) + } + } + return filtered +} + // compareNetIPLists compares a list of netip.Prefix with a list of strings. // return true if both lists are equal, false otherwise. func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool { @@ -2381,6 +2570,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) { } } + relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP()) + offerAnswer := peer.OfferAnswer{ IceCredentials: peer.IceCredentials{ UFrag: remoteCred.UFrag, @@ -2391,7 +2582,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) { RosenpassPubKey: rosenpassPubKey, RosenpassAddr: rosenpassAddr, RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), + RelaySrvIP: relayIP, SessionID: sessionID, } return &offerAnswer, nil } + +// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a +// netip.Addr. Returns the zero value for empty input and logs a warning +// for malformed payloads. 
+func decodeRelayIP(b []byte) netip.Addr { + if len(b) == 0 { + return netip.Addr{} + } + ip, ok := netip.AddrFromSlice(b) + if !ok { + log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b)) + return netip.Addr{} + } + return ip.Unmap() +} diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index 1419bc262..53d2c1122 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -41,6 +41,14 @@ func (e *Engine) setupSSHPortRedirection() error { } log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr) + if v6 := e.wgInterface.Address().IPv6; v6.IsValid() { + if err := e.firewall.AddInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil { + log.Warnf("failed to add IPv6 SSH port redirection: %v", err) + } else { + log.Infof("SSH port redirection enabled: [%s]:22 -> [%s]:22022", v6, v6) + } + } + return nil } @@ -137,12 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] continue } - peerIP := e.extractPeerIP(peerConfig) + peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) hostname := e.extractHostname(peerConfig) peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ Hostname: hostname, - IP: peerIP, + IP: peerV4, + IPv6: peerV6, FQDN: peerConfig.GetFqdn(), }) } @@ -150,18 +159,6 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] return peerInfo } -// extractPeerIP extracts IP address from peer's allowed IPs -func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string { - if len(peerConfig.GetAllowedIps()) == 0 { - return "" - } - - if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil { - return prefix.Addr().String() - } - return "" -} - // extractHostname extracts short hostname from FQDN func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string { fqdn := 
peerConfig.GetFqdn() @@ -208,7 +205,7 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { fullStatus := statusRecorder.GetFullStatus() for _, peerState := range fullStatus.Peers { - if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress || peerState.IPv6 == peerAddress { if len(peerState.SSHHostKey) > 0 { return peerState.SSHHostKey, true } @@ -262,6 +259,13 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error { return fmt.Errorf("start SSH server: %w", err) } + if v6 := wgAddr.IPv6; v6.IsValid() { + v6Addr := netip.AddrPortFrom(v6, sshserver.InternalSSHPort) + if err := server.AddListener(e.ctx, v6Addr); err != nil { + log.Warnf("failed to add IPv6 SSH listener: %v", err) + } + } + e.sshServer = server if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { @@ -330,6 +334,12 @@ func (e *Engine) cleanupSSHPortRedirection() error { } log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr) + if v6 := e.wgInterface.Address().IPv6; v6.IsValid() { + if err := e.firewall.RemoveInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil { + log.Debugf("failed to remove IPv6 SSH port redirection: %v", err) + } + } + return nil } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9fa4e51b2..834a49a09 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -67,6 +67,7 @@ import ( mgmt "github.com/netbirdio/netbird/shared/management/client" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" relayClient "github.com/netbirdio/netbird/shared/relay/client" + "github.com/netbirdio/netbird/shared/netiputil" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/shared/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -95,7 +96,7 @@ type MockWGIface struct { AddressFunc func() wgaddr.Address 
ToInterfaceFunc func() *net.Interface UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) - UpdateAddrFunc func(newAddr string) error + UpdateAddrFunc func(newAddr wgaddr.Address) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error AddAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error @@ -157,7 +158,7 @@ func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { return m.UpFunc() } -func (m *MockWGIface) UpdateAddr(newAddr string) error { +func (m *MockWGIface) UpdateAddr(newAddr wgaddr.Address) error { return m.UpdateAddrFunc(newAddr) } @@ -254,7 +255,7 @@ func TestEngine_SSH(t *testing.T) { ctx, cancel, &EngineConfig{ WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", + WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"), WgPrivateKey: key, WgPort: 33100, ServerSSHAllowed: true, @@ -431,7 +432,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", + WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"), WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -655,7 +656,7 @@ func TestEngine_Sync(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun103", - WgAddr: "100.64.0.1/24", + WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"), WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -825,7 +826,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, - WgAddr: wgAddr, + WgAddr: wgaddr.MustParseWGAddress(wgAddr), WgPrivateKey: 
key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -843,7 +844,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { opts := iface.WGIFaceOpts{ IFaceName: wgIfaceName, - Address: wgAddr, + Address: wgaddr.MustParseWGAddress(wgAddr), WGPort: engine.config.WgPort, WGPrivKey: key.String(), MTU: iface.DefaultMTU, @@ -1032,7 +1033,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, - WgAddr: wgAddr, + WgAddr: wgaddr.MustParseWGAddress(wgAddr), WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -1050,7 +1051,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { } opts := iface.WGIFaceOpts{ IFaceName: wgIfaceName, - Address: wgAddr, + Address: wgaddr.MustParseWGAddress(wgAddr), WGPort: 33100, WGPrivKey: key.String(), MTU: iface.DefaultMTU, @@ -1555,7 +1556,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin wgPort := 33100 + i conf := &EngineConfig{ WgIfaceName: ifaceName, - WgAddr: resp.PeerConfig.Address, + WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address), WgPrivateKey: key, WgPort: wgPort, MTU: iface.DefaultMTU, @@ -1671,7 +1672,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } @@ -1705,3 +1706,224 @@ func getPeers(e *Engine) int { return len(e.peerStore.PeersPubKey()) } + +func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte { + t.Helper() + b, 
err := netiputil.EncodePrefix(p) + require.NoError(t, err) + return b +} + +func TestEngine_hasIPv6Changed(t *testing.T) { + v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16") + + v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16") + v4v6.IPv6 = netip.MustParseAddr("fd00::1") + v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked() + + tests := []struct { + name string + current wgaddr.Address + confV6 []byte + expected bool + }{ + { + name: "no v6 before, no v6 now", + current: v4Only, + confV6: nil, + expected: false, + }, + { + name: "no v6 before, v6 added", + current: v4Only, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")), + expected: true, + }, + { + name: "had v6, now removed", + current: v4v6, + confV6: nil, + expected: true, + }, + { + name: "had v6, same v6", + current: v4v6, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")), + expected: false, + }, + { + name: "had v6, different v6", + current: v4v6, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")), + expected: true, + }, + { + name: "same v6 addr, different prefix length", + current: v4v6, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")), + expected: true, + }, + { + name: "decode error keeps status quo", + current: v4Only, + confV6: []byte{1, 2, 3}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{WgAddr: tt.current}, + } + conf := &mgmtProto.PeerConfig{ + AddressV6: tt.confV6, + } + assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf)) + }) + } +} + +func TestFilterAllowedIPs(t *testing.T) { + v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16") + v4v6Addr.IPv6 = netip.MustParseAddr("fd00::1") + v4v6Addr.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked() + + v4OnlyAddr := wgaddr.MustParseWGAddress("100.64.0.1/16") + + tests := []struct { + name string + addr wgaddr.Address + input []string + expected []string + }{ + { + 
name: "interface has v6, keep all", + addr: v4v6Addr, + input: []string{"100.64.0.1/32", "fd00::1/128"}, + expected: []string{"100.64.0.1/32", "fd00::1/128"}, + }, + { + name: "no v6, strip v6", + addr: v4OnlyAddr, + input: []string{"100.64.0.1/32", "fd00::1/128"}, + expected: []string{"100.64.0.1/32"}, + }, + { + name: "no v6, only v4", + addr: v4OnlyAddr, + input: []string{"100.64.0.1/32", "10.0.0.0/8"}, + expected: []string{"100.64.0.1/32", "10.0.0.0/8"}, + }, + { + name: "no v6, only v6 input", + addr: v4OnlyAddr, + input: []string{"fd00::1/128", "::/0"}, + expected: []string{}, + }, + { + name: "no v6, invalid prefix preserved", + addr: v4OnlyAddr, + input: []string{"100.64.0.1/32", "garbage"}, + expected: []string{"100.64.0.1/32", "garbage"}, + }, + { + name: "no v6, empty input", + addr: v4OnlyAddr, + input: []string{}, + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := tt.addr + engine := &Engine{ + config: &EngineConfig{}, + wgInterface: &MockWGIface{ + AddressFunc: func() wgaddr.Address { return addr }, + }, + } + result := engine.filterAllowedIPs(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestOverlayAddrsFromAllowedIPs(t *testing.T) { + ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64") + + tests := []struct { + name string + allowedIPs []string + ourV6Net netip.Prefix + wantV4 string + wantV6 string + }{ + { + name: "v4 only", + allowedIPs: []string{"100.64.0.1/32"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "v4 and v6 overlay", + allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::1/128"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "fd00:1234:5678:abcd::1", + }, + { + name: "v4, routed v6, overlay v6", + allowedIPs: []string{"100.64.0.1/32", "2001:db8::1/128", "fd00:1234:5678:abcd::1/128"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "fd00:1234:5678:abcd::1", + }, + { + name: "routed v6 
/128 outside our subnet is ignored", + allowedIPs: []string{"100.64.0.1/32", "2001:db8::1/128"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "routed v6 prefix is ignored", + allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::/64"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "no v6 subnet configured", + allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::1/128"}, + ourV6Net: netip.Prefix{}, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "v4 /24 route is ignored", + allowedIPs: []string{"100.64.0.0/24", "100.64.0.1/32"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v4, v6 := overlayAddrsFromAllowedIPs(tt.allowedIPs, tt.ourV6Net) + if tt.wantV4 == "" { + assert.False(t, v4.IsValid(), "expected no v4") + } else { + assert.Equal(t, tt.wantV4, v4.String(), "v4") + } + if tt.wantV6 == "" { + assert.False(t, v6.IsValid(), "expected no v6") + } else { + assert.Equal(t, tt.wantV6, v6.String(), "v6") + } + }) + } +} diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 39e9bacfa..2eeac1954 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -26,7 +26,7 @@ type wgIfaceBase interface { Address() wgaddr.Address ToInterface() *net.Interface Up() (*udpmux.UniversalUDPMuxDefault, error) - UpdateAddr(newAddr string) error + UpdateAddr(newAddr wgaddr.Address) error GetProxy() wgproxy.Proxy GetProxyPort() uint16 UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go index 792d04215..60b8baadb 100644 --- a/client/internal/lazyconn/activity/listener_bind.go +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -57,6 +57,7 @@ func 
NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc // deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. // Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). // It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +// For IPv6-only peers, the last two bytes of the v6 address are used. func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { if len(allowedIPs) == 0 { return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") @@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e ourNetwork := wgIface.Address().Network + // Try v4 first (preferred: deterministic from overlay IP) var peerIP netip.Addr for _, allowedIP := range allowedIPs { ip := allowedIP.Addr() @@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e } } - if !peerIP.IsValid() { - return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + if peerIP.IsValid() { + octets := peerIP.As4() + return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil } - octets := peerIP.As4() - fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) - return fakeIP, nil + // Fallback: use last two bytes of first v6 overlay IP + addr := wgIface.Address() + if addr.IPv6Net.IsValid() { + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if ip.Is6() && addr.IPv6Net.Contains(ip) { + raw := ip.As16() + return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil + } + } + } + + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") } func (d *BindListener) setupLazyConn() error { diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go index f86dd3877..1baaae6be 100644 --- 
a/client/internal/lazyconn/activity/listener_bind_test.go +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -3,7 +3,6 @@ package activity import ( "net" "net/netip" - "runtime" "testing" "time" @@ -18,10 +17,6 @@ import ( peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) -func isBindListenerPlatform() bool { - return runtime.GOOS == "windows" || runtime.GOOS == "js" -} - // mockEndpointManager implements device.EndpointManager for testing type mockEndpointManager struct { endpoints map[netip.Addr]net.Conn @@ -181,10 +176,6 @@ func TestBindListener_Close(t *testing.T) { } func TestManager_BindMode(t *testing.T) { - if !isBindListenerPlatform() { - t.Skip("BindListener only used on Windows/JS platforms") - } - mockEndpointMgr := newMockEndpointManager() mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} @@ -226,10 +217,6 @@ func TestManager_BindMode(t *testing.T) { } func TestManager_BindMode_MultiplePeers(t *testing.T) { - if !isBindListenerPlatform() { - t.Skip("BindListener only used on Windows/JS platforms") - } - mockEndpointMgr := newMockEndpointManager() mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 1c11378c8..cccc0669f 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -4,14 +4,12 @@ import ( "errors" "net" "net/netip" - "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" @@ -75,16 +73,6 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) return NewUDPListener(m.wgIface, peerCfg) } - // BindListener is used on Windows, JS, and netstack 
platforms: - // - JS: Cannot listen to UDP sockets - // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default - // gateway points to, preventing them from reaching the loopback interface. - // - Netstack: Allows multiple instances on the same host without port conflicts. - // BindListener bypasses these issues by passing data directly through the bind. - if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() { - return NewUDPListener(m.wgIface, peerCfg) - } - provider, ok := m.wgIface.(bindProvider) if !ok { return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index b6b3c6091..fc47bda39 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -6,7 +6,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn/activity" @@ -91,8 +90,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) { m.routesMu.Lock() defer m.routesMu.Unlock() - maps.Clear(m.peerToHAGroups) - maps.Clear(m.haGroupToPeers) + clear(m.peerToHAGroups) + clear(m.haGroupToPeers) for haUniqueID, routes := range haMap { var peers []string diff --git a/client/internal/listener/network_change.go b/client/internal/listener/network_change.go index 08bf5fd52..e0aa43abe 100644 --- a/client/internal/listener/network_change.go +++ b/client/internal/listener/network_change.go @@ -5,4 +5,5 @@ type NetworkChangeListener interface { // OnNetworkChanged invoke when network settings has been changed OnNetworkChanged(string) SetInterfaceIP(string) + SetInterfaceIPv6(string) } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index 2420b1fdf..6f1da5138 100644 --- 
a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -316,7 +316,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) { case nftypes.TCP, nftypes.UDP, nftypes.SCTP: srcPort = flow.TupleOrig.Proto.SourcePort dstPort = flow.TupleOrig.Proto.DestinationPort - case nftypes.ICMP: + case nftypes.ICMP, nftypes.ICMPv6: icmpType = flow.TupleOrig.Proto.ICMPType icmpCode = flow.TupleOrig.Proto.ICMPCode } @@ -359,8 +359,14 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool { } // fallback if mark rules are not in place - wgnet := c.iface.Address().Network - return wgnet.Contains(srcIP) || wgnet.Contains(dstIP) + addr := c.iface.Address() + if addr.Network.Contains(srcIP) || addr.Network.Contains(dstIP) { + return true + } + if addr.IPv6Net.IsValid() { + return addr.IPv6Net.Contains(srcIP) || addr.IPv6Net.Contains(dstIP) + } + return false } // mapRxPackets maps packet counts to RX based on flow direction @@ -419,17 +425,16 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes } // fallback if marks are not set - wgaddr := c.iface.Address().IP - wgnetwork := c.iface.Address().Network + addr := c.iface.Address() switch { - case wgaddr == srcIP: + case addr.IP == srcIP || (addr.IPv6.IsValid() && addr.IPv6 == srcIP): return nftypes.Egress - case wgaddr == dstIP: + case addr.IP == dstIP || (addr.IPv6.IsValid() && addr.IPv6 == dstIP): return nftypes.Ingress - case wgnetwork.Contains(srcIP): + case addr.Network.Contains(srcIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(srcIP)): // netbird network -> resource network return nftypes.Ingress - case wgnetwork.Contains(dstIP): + case addr.Network.Contains(dstIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(dstIP)): // resource network -> netbird network return nftypes.Egress } diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index a033a2a7c..8f8e68784 100644 --- 
a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -24,15 +24,17 @@ type Logger struct { cancel context.CancelFunc statusRecorder *peer.Status wgIfaceNet netip.Prefix + wgIfaceNetV6 netip.Prefix dnsCollection atomic.Bool exitNodeCollection atomic.Bool Store types.Store } -func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger { +func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger { return &Logger{ statusRecorder: statusRecorder, wgIfaceNet: wgIfaceIPNet, + wgIfaceNetV6: wgIfaceIPNetV6, Store: store.NewMemoryStore(), } } @@ -88,11 +90,11 @@ func (l *Logger) startReceiver() { var isSrcExitNode bool var isDestExitNode bool - if !l.wgIfaceNet.Contains(event.SourceIP) { + if !l.isOverlayIP(event.SourceIP) { event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) } - if !l.wgIfaceNet.Contains(event.DestIP) { + if !l.isOverlayIP(event.DestIP) { event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) } @@ -136,6 +138,10 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { l.exitNodeCollection.Store(exitNodeCollection) } +func (l *Logger) isOverlayIP(ip netip.Addr) bool { + return l.wgIfaceNet.Contains(ip) || (l.wgIfaceNetV6.IsValid() && l.wgIfaceNetV6.Contains(ip)) +} + func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection if !l.dnsCollection.Load() && event.Protocol == types.UDP && diff --git a/client/internal/netflow/logger/logger_test.go b/client/internal/netflow/logger/logger_test.go index 1144544d8..ad2eedef2 100644 --- a/client/internal/netflow/logger/logger_test.go +++ b/client/internal/netflow/logger/logger_test.go @@ -12,7 +12,7 @@ import ( ) func TestStore(t *testing.T) { - logger := logger.New(nil, netip.Prefix{}) + logger := logger.New(nil, netip.Prefix{}, netip.Prefix{}) logger.Enable() event := types.EventFields{ diff --git 
a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index 7752c97b0..eff083dbf 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -35,11 +35,12 @@ type Manager struct { // NewManager creates a new netflow manager func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { - var prefix netip.Prefix + var prefix, prefixV6 netip.Prefix if iface != nil { prefix = iface.Address().Network + prefixV6 = iface.Address().IPv6Net } - flowLogger := logger.New(statusRecorder, prefix) + flowLogger := logger.New(statusRecorder, prefix, prefixV6) var ct nftypes.ConnTracker if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { @@ -269,7 +270,7 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent { }, } - if event.Protocol == nftypes.ICMP { + if event.Protocol == nftypes.ICMP || event.Protocol == nftypes.ICMPv6 { protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{ IcmpInfo: &proto.ICMPInfo{ IcmpType: uint32(event.ICMPType), diff --git a/client/internal/netflow/store/memory.go b/client/internal/netflow/store/memory.go index b695a0a12..a44505e96 100644 --- a/client/internal/netflow/store/memory.go +++ b/client/internal/netflow/store/memory.go @@ -3,8 +3,6 @@ package store import ( "sync" - "golang.org/x/exp/maps" - "github.com/google/uuid" "github.com/netbirdio/netbird/client/internal/netflow/types" @@ -30,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) { func (m *Memory) Close() { m.mux.Lock() defer m.mux.Unlock() - maps.Clear(m.events) + clear(m.events) } func (m *Memory) GetEvents() []*types.Event { diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index f76146ba3..3f7d0d0ad 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -19,6 +19,7 @@ const ( ICMP = Protocol(1) TCP = Protocol(6) UDP = Protocol(17) + ICMPv6 = 
Protocol(58) SCTP = Protocol(132) ) @@ -30,6 +31,8 @@ func (p Protocol) String() string { return "TCP" case 17: return "UDP" + case 58: + return "ICMPv6" case 132: return "SCTP" default: diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8d1585b3f..1e416bfe7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) - relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) - if err != nil { - return err + forceRelay := IsForceRelayed() + if !forceRelay { + relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() + workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) + if err != nil { + return err + } + conn.workerICE = workerICE } - conn.workerICE = workerICE conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) - if !isForceRelayed() { + if !forceRelay { conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } @@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgWatcherCancel() } conn.workerRelay.CloseConn() - conn.workerICE.Close() + if conn.workerICE != nil { + conn.workerICE.Close() + } if conn.wgProxyRelay != nil { err := conn.wgProxyRelay.CloseConn() @@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) { // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. 
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { conn.dumpState.RemoteCandidate() - conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + if conn.workerICE != nil { + conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + } } // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established @@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusConnecting } -func (conn *Conn) isConnectedOnAllWay() (connected bool) { - // would be better to protect this with a mutex, but it could cause deadlock with Close function - +// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports. +// +// The result is a tri-state: +// - ConnStatusConnected: all available transports are up +// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting +// - ConnStatusDisconnected: no working transport +func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) { defer func() { - if !connected { + if status == guard.ConnStatusDisconnected { conn.logTraceConnState() } }() - // For JS platform: only relay connection is supported - if runtime.GOOS == "js" { - return conn.statusRelay.Get() == worker.StatusConnected + iceWorkerCreated := conn.workerICE != nil + + var iceInProgress bool + if iceWorkerCreated { + iceInProgress = conn.workerICE.InProgress() } - // For non-JS platforms: check ICE connection status - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { - return false - } - - // If relay is supported with peer, it must also be connected - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == worker.StatusDisconnected { - return false - } - } - - return true + return evalConnStatus(connStatusInputs{ + forceRelay: IsForceRelayed(), + peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(), + relayConnected: 
conn.statusRelay.Get() == worker.StatusConnected, + remoteSupportsICE: conn.handshaker.RemoteICESupported(), + iceWorkerCreated: iceWorkerCreated, + iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected, + iceInProgress: iceInProgress, + }) } func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { @@ -926,3 +935,43 @@ func isController(config ConnConfig) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } + +func evalConnStatus(in connStatusInputs) guard.ConnStatus { + // "Relay up and needed" — the peer uses relay and the transport is connected. + relayUsedAndUp := in.peerUsesRelay && in.relayConnected + + // Force-relay mode: ICE never runs. Relay is the only transport and must be up. + if in.forceRelay { + return boolToConnStatus(relayUsedAndUp) + } + + // Remote peer doesn't support ICE, or we haven't created the worker yet: + // relay is the only possible transport. + if !in.remoteSupportsICE || !in.iceWorkerCreated { + return boolToConnStatus(relayUsedAndUp) + } + + // ICE counts as "up" when the status is anything other than Disconnected, OR + // when a negotiation is currently in progress (so we don't spam offers while one is in flight). + iceUp := in.iceStatusConnecting || in.iceInProgress + + // Relay side is acceptable if the peer doesn't rely on relay, or relay is connected. + relayOK := !in.peerUsesRelay || in.relayConnected + + switch { + case iceUp && relayOK: + return guard.ConnStatusConnected + case relayUsedAndUp: + // Relay is up but ICE is down — partially connected. 
+ return guard.ConnStatusPartiallyConnected + default: + return guard.ConnStatusDisconnected + } +} + +func boolToConnStatus(connected bool) guard.ConnStatus { + if connected { + return guard.ConnStatusConnected + } + return guard.ConnStatusDisconnected +} diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go index 73acc5ef5..b43e245f3 100644 --- a/client/internal/peer/conn_status.go +++ b/client/internal/peer/conn_status.go @@ -13,6 +13,20 @@ const ( StatusConnected ) +// connStatusInputs is the primitive-valued snapshot of the state that drives the +// tri-state connection classification. Extracted so the decision logic can be unit-tested +// without constructing full Worker/Handshaker objects. +type connStatusInputs struct { + forceRelay bool // NB_FORCE_RELAY or JS/WASM + peerUsesRelay bool // remote peer advertises relay support AND local has relay + relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay) + remoteSupportsICE bool // remote peer sent ICE credentials + iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode) + iceStatusConnecting bool // statusICE is anything other than Disconnected + iceInProgress bool // a negotiation is currently in flight +} + + // ConnStatus describe the status of a peer's connection type ConnStatus int32 diff --git a/client/internal/peer/conn_status_eval_test.go b/client/internal/peer/conn_status_eval_test.go new file mode 100644 index 000000000..66393cafe --- /dev/null +++ b/client/internal/peer/conn_status_eval_test.go @@ -0,0 +1,201 @@ +package peer + +import ( + "testing" + + "github.com/netbirdio/netbird/client/internal/peer/guard" +) + +func TestEvalConnStatus_ForceRelay(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "force relay, peer uses relay, relay up", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: true, + }, + want: 
guard.ConnStatusConnected, + }, + { + name: "force relay, peer uses relay, relay down", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: false, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "force relay, peer does NOT use relay - disconnected forever", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: false, + relayConnected: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_ICEUnavailable(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "remote does not support ICE, peer uses relay, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer uses relay, relay down", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE worker not yet created, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: true, + iceWorkerCreated: false, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer does not use relay", + in: connStatusInputs{ + peerUsesRelay: false, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_FullyAvailable(t *testing.T) { + base 
:= connStatusInputs{ + remoteSupportsICE: true, + iceWorkerCreated: true, + } + + tests := []struct { + name string + mutator func(*connStatusInputs) + want guard.ConnStatus + }{ + { + name: "ICE connected, relay connected, peer uses relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE connected, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE InProgress only, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.iceStatusConnecting = false + in.iceInProgress = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE down, relay up, peer uses relay -> partial", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusPartiallyConnected, + }, + { + name: "ICE down, peer does NOT use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = false + in.iceStatusConnecting = true + }, + // relayOK = false (peer uses relay but it's down), iceUp = true + // first switch arm fails (relayOK false), relayUsedAndUp = false (relay down), + // falls into default: Disconnected. 
+ want: guard.ConnStatusDisconnected, + }, + { + name: "ICE down, relay up but peer does not use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = true // not actually used since peer doesn't rely on it + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + in := base + tc.mutator(&in) + if got := evalConnStatus(in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in) + } + }) + } +} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 7f500c410..ed6a3af53 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -7,12 +7,38 @@ import ( ) const ( - EnvKeyNBForceRelay = "NB_FORCE_RELAY" + EnvKeyNBForceRelay = "NB_FORCE_RELAY" + EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS" ) -func isForceRelayed() bool { +func IsForceRelayed() bool { if runtime.GOOS == "js" { return true } return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") } + +// OverrideRelayURLs returns the relay server URL list set in +// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether +// the override is active. When the env var is unset, the boolean is false +// and the caller should keep the list received from the management server. +// Intended for lab/debug scenarios where a peer must pin to a specific home +// relay regardless of what management offers. 
+func OverrideRelayURLs() ([]string, bool) { + raw := os.Getenv(EnvKeyNBHomeRelayServers) + if raw == "" { + return nil, false + } + parts := strings.Split(raw, ",") + urls := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + urls = append(urls, p) + } + } + if len(urls) == 0 { + return nil, false + } + return urls, true +} diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go index d93403730..2e5efbcc5 100644 --- a/client/internal/peer/guard/guard.go +++ b/client/internal/peer/guard/guard.go @@ -8,7 +8,19 @@ import ( log "github.com/sirupsen/logrus" ) -type isConnectedFunc func() bool +// ConnStatus represents the connection state as seen by the guard. +type ConnStatus int + +const ( + // ConnStatusDisconnected means neither ICE nor Relay is connected. + ConnStatusDisconnected ConnStatus = iota + // ConnStatusPartiallyConnected means Relay is connected but ICE is not. + ConnStatusPartiallyConnected + // ConnStatusConnected means all required connections are established. + ConnStatusConnected +) + +type connStatusFunc func() ConnStatus // Guard is responsible for the reconnection logic. // It will trigger to send an offer to the peer then has connection issues. @@ -20,14 +32,14 @@ type isConnectedFunc func() bool // - ICE candidate changes type Guard struct { log *log.Entry - isConnectedOnAllWay isConnectedFunc + isConnectedOnAllWay connStatusFunc timeout time.Duration srWatcher *SRWatcher relayedConnDisconnected chan struct{} iCEConnDisconnected chan struct{} } -func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { +func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { return &Guard{ log: log, isConnectedOnAllWay: isConnectedFn, @@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() { } } -// reconnectLoopWithRetry periodically check the connection status. 
-// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported +// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity. +// +// Behavior depends on the connection state reported by isConnectedOnAllWay: +// - Connected: no action, the peer is fully reachable. +// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling +// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all. +// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches +// to one attempt per hour. This limits signaling traffic when relay already provides connectivity. +// +// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry +// counter and backoff ticker, giving ICE a fresh chance after network conditions change. func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { srReconnectedChan := g.srWatcher.NewListener() defer g.srWatcher.RemoveListener(srReconnectedChan) @@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { tickerChannel := ticker.C + iceState := &iceRetryState{log: g.log} + defer iceState.reset() + for { select { - case t := <-tickerChannel: - if t.IsZero() { - g.log.Infof("retry timed out, stop periodic offer sending") - // after backoff timeout the ticker.C will be closed. 
We need to a dummy channel to avoid loop - tickerChannel = make(<-chan time.Time) - continue + case <-tickerChannel: + switch g.isConnectedOnAllWay() { + case ConnStatusConnected: + // all good, nothing to do + case ConnStatusDisconnected: + callback() + case ConnStatusPartiallyConnected: + if iceState.shouldRetry() { + callback() + } else { + iceState.enterHourlyMode() + ticker.Stop() + tickerChannel = iceState.hourlyC() + } } - if !g.isConnectedOnAllWay() { - callback() - } case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-g.iCEConnDisconnected: g.log.Debugf("ICE connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-srReconnectedChan: g.log.Debugf("has network changes, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-ctx.Done(): g.log.Debugf("context is done, stop reconnect loop") @@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker { return backoff.NewTicker(bo) } -func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, RandomizationFactor: 0.1, diff --git a/client/internal/peer/guard/ice_retry_state.go b/client/internal/peer/guard/ice_retry_state.go new file mode 100644 index 000000000..01dc1bf2d --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state.go @@ -0,0 +1,61 @@ +package guard + +import ( + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // maxICERetries is the maximum 
number of ICE offer attempts when relay is connected + maxICERetries = 3 + // iceRetryInterval is the periodic retry interval after ICE retries are exhausted + iceRetryInterval = 1 * time.Hour +) + +// iceRetryState tracks the limited ICE retry attempts when relay is already connected. +// After maxICERetries attempts it switches to a periodic hourly retry. +type iceRetryState struct { + log *log.Entry + retries int + hourly *time.Ticker +} + +func (s *iceRetryState) reset() { + s.retries = 0 + if s.hourly != nil { + s.hourly.Stop() + s.hourly = nil + } +} + +// shouldRetry reports whether the caller should send another ICE offer on this tick. +// Returns false when the per-cycle retry budget is exhausted and the caller must switch +// to the hourly ticker via enterHourlyMode + hourlyC. +func (s *iceRetryState) shouldRetry() bool { + if s.hourly != nil { + s.log.Debugf("hourly ICE retry attempt") + return true + } + + s.retries++ + if s.retries <= maxICERetries { + s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries) + return true + } + + return false +} + +// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false. 
+func (s *iceRetryState) enterHourlyMode() { + s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries) + s.hourly = time.NewTicker(iceRetryInterval) +} + +func (s *iceRetryState) hourlyC() <-chan time.Time { + if s.hourly == nil { + return nil + } + return s.hourly.C +} diff --git a/client/internal/peer/guard/ice_retry_state_test.go b/client/internal/peer/guard/ice_retry_state_test.go new file mode 100644 index 000000000..6a5b5a76f --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state_test.go @@ -0,0 +1,103 @@ +package guard + +import ( + "testing" + + log "github.com/sirupsen/logrus" +) + +func newTestRetryState() *iceRetryState { + return &iceRetryState{log: log.NewEntry(log.StandardLogger())} +} + +func TestICERetryState_AllowsInitialBudget(t *testing.T) { + s := newTestRetryState() + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries) + } + } +} + +func TestICERetryState_ExhaustsAfterBudget(t *testing.T) { + s := newTestRetryState() + + for i := 0; i < maxICERetries; i++ { + _ = s.shouldRetry() + } + + if s.shouldRetry() { + t.Fatalf("shouldRetry returned true after budget exhausted, want false") + } +} + +func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) { + s := newTestRetryState() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode") + } +} + +func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + + s.enterHourlyMode() + defer s.reset() + + if s.hourlyC() == nil { + t.Fatalf("hourlyC returned nil after enterHourlyMode") + } +} + +func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) { + s := newTestRetryState() + s.enterHourlyMode() + defer s.reset() + + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false in 
hourly mode, want true") + } + + // Subsequent calls also return true — we keep retrying on each hourly tick. + if !s.shouldRetry() { + t.Fatalf("second shouldRetry returned false in hourly mode, want true") + } +} + +func TestICERetryState_ResetRestoresBudget(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + s.enterHourlyMode() + + s.reset() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel after reset") + } + if s.retries != 0 { + t.Fatalf("retries = %d after reset, want 0", s.retries) + } + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i) + } + } +} + +func TestICERetryState_ResetIsIdempotent(t *testing.T) { + s := newTestRetryState() + s.reset() + s.reset() // second call must not panic or re-stop a nil ticker + + if s.hourlyC() != nil { + t.Fatalf("hourlyC non-nil after double reset") + } +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 6f4f5ad4f..0befd7438 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove return srw } -func (w *SRWatcher) Start() { +func (w *SRWatcher) Start(disableICEMonitor bool) { w.mu.Lock() defer w.mu.Unlock() @@ -50,8 +50,10 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) - go iceMonitor.Start(ctx, w.onICEChanged) + if !disableICEMonitor { + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) + go iceMonitor.Start(ctx, w.onICEChanged) + } w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git 
a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 9b50cecd1..1d44096b6 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -3,7 +3,9 @@ package peer import ( "context" "errors" + "net/netip" "sync" + "sync/atomic" log "github.com/sirupsen/logrus" @@ -39,10 +41,18 @@ type OfferAnswer struct { // relay server address RelaySrvAddress string + // RelaySrvIP is the IP the remote peer is connected to on its + // relay server. Used as a dial target if DNS for RelaySrvAddress + // fails. Zero value if the peer did not advertise an IP. + RelaySrvIP netip.Addr // SessionID is the unique identifier of the session, used to discard old messages SessionID *ICESessionID } +func (o *OfferAnswer) hasICECredentials() bool { + return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != "" +} + type Handshaker struct { mu sync.Mutex log *log.Entry @@ -59,6 +69,10 @@ type Handshaker struct { relayListener *AsyncOfferListener iceListener func(remoteOfferAnswer *OfferAnswer) + // remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers. + // When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses. 
+ remoteICESupported atomic.Bool + // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection @@ -66,7 +80,7 @@ type Handshaker struct { } func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { - return &Handshaker{ + h := &Handshaker{ log: log, config: config, signaler: signaler, @@ -76,6 +90,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W remoteOffersCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer), } + // assume remote supports ICE until we learn otherwise from received offers + h.remoteICESupported.Store(ice != nil) + return h +} + +func (h *Handshaker) RemoteICESupported() bool { + return h.remoteICESupported.Load() } func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { @@ -90,18 +111,20 @@ func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { 
h.iceListener(&remoteOfferAnswer) } @@ -110,18 +133,20 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } case remoteOfferAnswer := <-h.remoteAnswerCh: - h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): @@ -183,20 +208,39 @@ func (h *Handshaker) sendAnswer() error { } func (h *Handshaker) buildOfferAnswer() OfferAnswer { - uFrag, pwd := h.ice.GetLocalUserCredentials() - sid := h.ice.SessionID() answer := OfferAnswer{ - IceCredentials: IceCredentials{uFrag, pwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassAddr: h.config.RosenpassConfig.Addr, - SessionID: &sid, } - if addr, err := h.relay.RelayInstanceAddress(); err == nil { + if h.ice != nil && h.RemoteICESupported() { + uFrag, pwd := h.ice.GetLocalUserCredentials() + sid := h.ice.SessionID() + answer.IceCredentials = IceCredentials{uFrag, pwd} + answer.SessionID = &sid + } + + if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil { answer.RelaySrvAddress = addr + answer.RelaySrvIP = ip } return answer } + +func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) { + hasICE := offer.hasICECredentials() + prev := 
h.remoteICESupported.Swap(hasICE) + if prev != hasICE { + if hasICE { + h.log.Infof("remote peer started sending ICE credentials") + } else { + h.log.Infof("remote peer stopped sending ICE credentials") + if h.ice != nil { + h.ice.Close() + } + } + } +} diff --git a/client/internal/peer/notifier_test.go b/client/internal/peer/notifier_test.go index bbdc00e13..0b7722b0c 100644 --- a/client/internal/peer/notifier_test.go +++ b/client/internal/peer/notifier_test.go @@ -8,6 +8,7 @@ import ( type mocListener struct { lastState int wg sync.WaitGroup + peersWg sync.WaitGroup peers int } @@ -33,6 +34,7 @@ func (l *mocListener) OnAddressChanged(host, addr string) { } func (l *mocListener) OnPeersListChanged(size int) { l.peers = size + l.peersWg.Done() } func (l *mocListener) setWaiter() { @@ -43,6 +45,14 @@ func (l *mocListener) wait() { l.wg.Wait() } +func (l *mocListener) setPeersWaiter() { + l.peersWg.Add(1) +} + +func (l *mocListener) waitPeers() { + l.peersWg.Wait() +} + func Test_notifier_serverState(t *testing.T) { type scenario struct { @@ -72,11 +82,13 @@ func Test_notifier_serverState(t *testing.T) { func Test_notifier_SetListener(t *testing.T) { listener := &mocListener{} listener.setWaiter() + listener.setPeersWaiter() n := newNotifier() n.lastNotification = stateConnecting n.setListener(listener) listener.wait() + listener.waitPeers() if listener.lastState != n.lastNotification { t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) } @@ -85,9 +97,14 @@ func Test_notifier_SetListener(t *testing.T) { func Test_notifier_RemoveListener(t *testing.T) { listener := &mocListener{} listener.setWaiter() + listener.setPeersWaiter() n := newNotifier() n.lastNotification = stateConnecting n.setListener(listener) + // setListener replays cached state on a goroutine; wait for both the state + // and peers callbacks to finish so we don't race on listener.peers. 
+ listener.wait() + listener.waitPeers() n.removeListener() n.peerListChanged(1) diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index b28906625..5e437d96b 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -46,23 +46,27 @@ func (s *Signaler) Ready() bool { // SignalOfferAnswer signals either an offer or an answer to remote peer func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { - sessionIDBytes, err := offerAnswer.SessionID.Bytes() - if err != nil { - log.Warnf("failed to get session ID bytes: %v", err) + var sessionIDBytes []byte + if offerAnswer.SessionID != nil { + var err error + sessionIDBytes, err = offerAnswer.SessionID.Bytes() + if err != nil { + log.Warnf("failed to get session ID bytes: %v", err) + } } - msg, err := signal.MarshalCredential( - s.wgPrivateKey, - offerAnswer.WgListenPort, - remoteKey, - &signal.Credential{ + msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{ + Type: bodyType, + WgListenPort: offerAnswer.WgListenPort, + Credential: &signal.Credential{ UFrag: offerAnswer.IceCredentials.UFrag, Pwd: offerAnswer.IceCredentials.Pwd, }, - bodyType, - offerAnswer.RosenpassPubKey, - offerAnswer.RosenpassAddr, - offerAnswer.RelaySrvAddress, - sessionIDBytes) + RosenpassPubKey: offerAnswer.RosenpassPubKey, + RosenpassAddr: offerAnswer.RosenpassAddr, + RelaySrvAddress: offerAnswer.RelaySrvAddress, + RelaySrvIP: offerAnswer.RelaySrvIP, + SessionID: sessionIDBytes, + }) if err != nil { return err } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index abedc208e..df746fa13 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -53,6 +53,7 @@ type RouterState struct { type State struct { Mux *sync.RWMutex IP string + IPv6 string PubKey string FQDN string ConnStatus ConnStatus @@ -106,6 +107,7 @@ func (s *State) GetRoutes() 
map[string]struct{} { // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { IP string + IPv6 string PubKey string KernelInterface bool FQDN string @@ -259,7 +261,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) { } // AddPeer adds peer to Daemon status map -func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error { +func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string) error { d.mux.Lock() defer d.mux.Unlock() @@ -270,6 +272,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error { d.peers[peerPubKey] = State{ PubKey: peerPubKey, IP: ip, + IPv6: ipv6, ConnStatus: StatusIdle, FQDN: fqdn, Mux: new(sync.RWMutex), @@ -320,10 +323,10 @@ func (d *Status) RemovePeer(peerPubKey string) error { // UpdatePeerState updates peer status func (d *Status) UpdatePeerState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -343,23 +346,29 @@ func (d *Status) UpdatePeerState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } - + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) // when we close the connection we will not notify the router manager - if receivedState.ConnStatus == StatusIdle { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + notifyRouter := receivedState.ConnStatus == StatusIdle + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() + + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { d.mux.Lock() - defer 
d.mux.Unlock() peerState, ok := d.peers[peer] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -371,17 +380,20 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R d.routeIDLookup.AddRemoteRouteID(resourceId, pref) } + numPeers := d.numOfPeers() + d.mux.Unlock() + // todo: consider to make sense of this notification or not - d.notifyPeerListChanged() + d.notifier.peerListChanged(numPeers) return nil } func (d *Status) RemovePeerStateRoute(peer string, route string) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[peer] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -393,8 +405,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error { d.routeIDLookup.RemoveRemoteRouteID(pref) } + numPeers := d.numOfPeers() + d.mux.Unlock() + // todo: consider to make sense of this notification or not - d.notifyPeerListChanged() + d.notifier.peerListChanged(numPeers) return nil } @@ -410,10 +425,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) { func (d *Status) UpdatePeerICEState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -431,22 +446,28 @@ func (d *Status) UpdatePeerICEState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + 
+ if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerRelayedState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -461,22 +482,28 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -490,22 +517,28 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + 
numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -522,12 +555,18 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } @@ -594,17 +633,33 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() - defer d.mux.Unlock() if !d.peerListChangedForNotification { + d.mux.Unlock() return } d.peerListChangedForNotification = false - d.notifyPeerListChanged() + numPeers := d.numOfPeers() + // snapshot per-peer router state to deliver after the lock is released + type routerDispatch struct { + 
peerID string + snapshot map[string]RouterState + } + dispatches := make([]routerDispatch, 0, len(d.peers)) for key := range d.peers { - d.notifyPeerStateChangeListeners(key) + snapshot := d.snapshotRouterPeersLocked(key, true) + if snapshot != nil { + dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot}) + } + } + + d.mux.Unlock() + + d.notifier.peerListChanged(numPeers) + for _, rd := range dispatches { + d.dispatchRouterPeers(rd.peerID, rd.snapshot) } } @@ -655,10 +710,15 @@ func (d *Status) GetLocalPeerState() LocalPeerState { // UpdateLocalPeerState updates local peer status func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { d.mux.Lock() - defer d.mux.Unlock() - d.localPeer = localPeerState - d.notifyAddressChanged() + fqdn := d.localPeer.FQDN + ip := d.localPeer.IP + if d.localPeer.IPv6 != "" { + ip = ip + "\n" + d.localPeer.IPv6 + } + d.mux.Unlock() + + d.notifier.localAddressChanged(fqdn, ip) } // AddLocalPeerStateRoute adds a route to the local peer state @@ -721,30 +781,36 @@ func (d *Status) CleanLocalPeerStateRoutes() { // CleanLocalPeerState cleans local peer status func (d *Status) CleanLocalPeerState() { d.mux.Lock() - defer d.mux.Unlock() - d.localPeer = LocalPeerState{} - d.notifyAddressChanged() + fqdn := d.localPeer.FQDN + ip := d.localPeer.IP + d.mux.Unlock() + + d.notifier.localAddressChanged(fqdn, ip) } // MarkManagementDisconnected sets ManagementState to disconnected func (d *Status) MarkManagementDisconnected(err error) { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.managementState = false d.managementError = err + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // MarkManagementConnected sets ManagementState to connected func (d *Status) MarkManagementConnected() { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.managementState = true d.managementError = nil + mgm := 
d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // UpdateSignalAddress update the address of the signal server @@ -778,21 +844,25 @@ func (d *Status) UpdateLazyConnection(enabled bool) { // MarkSignalDisconnected sets SignalState to disconnected func (d *Status) MarkSignalDisconnected(err error) { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.signalState = false d.signalError = err + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // MarkSignalConnected sets SignalState to connected func (d *Status) MarkSignalConnected() { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.signalState = true d.signalError = nil + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) { @@ -919,7 +989,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { // if the server connection is not established then we will use the general address // in case of connection we will use the instance specific address - instanceAddr, err := d.relayMgr.RelayInstanceAddress() + instanceAddr, _, err := d.relayMgr.RelayInstanceAddress() if err != nil { // TODO add their status for _, r := range d.relayMgr.ServerURLs() { @@ -1012,18 +1082,17 @@ func (d *Status) RemoveConnectionListener() { d.notifier.removeListener() } -func (d *Status) onConnectionChanged() { - d.notifier.updateServerStates(d.managementState, d.signalState) -} - -// notifyPeerStateChangeListeners notifies route manager about the change in peer state -func (d *Status) notifyPeerStateChangeListeners(peerID string) { - subs, ok := d.changeNotify[peerID] - if !ok { - return +// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers. +// Caller MUST hold d.mux. 
Returns nil when there are no subscribers for peerID +// or when notify is false. The snapshot is consumed later by dispatchRouterPeers +// outside the lock so the channel send cannot stall any d.mux holder. +func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState { + if !notify { + return nil + } + if _, ok := d.changeNotify[peerID]; !ok { + return nil } - - // collect the relevant data for router peers routerPeers := make(map[string]RouterState, len(d.changeNotify)) for pid := range d.changeNotify { s, ok := d.peers[pid] @@ -1031,13 +1100,35 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) { log.Warnf("router peer not found in peers list: %s", pid) continue } - routerPeers[pid] = RouterState{ Status: s.ConnStatus, Relayed: s.Relayed, Latency: s.Latency, } } + return routerPeers +} + +// dispatchRouterPeers delivers a previously snapshotted router-state map to +// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a +// fresh, short read of d.changeNotify under the lock to grab subscriber +// channels, then sends outside the lock so a slow consumer cannot block other +// d.mux holders. The send itself stays blocking (only short-circuited by the +// subscriber's context) so peer state transitions are not silently dropped. 
+func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) { + if routerPeers == nil { + return + } + + d.mux.Lock() + subsMap, ok := d.changeNotify[peerID] + subs := make([]*StatusChangeSubscription, 0, len(subsMap)) + if ok { + for _, sub := range subsMap { + subs = append(subs, sub) + } + } + d.mux.Unlock() for _, sub := range subs { select { @@ -1047,14 +1138,6 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) { } } -func (d *Status) notifyPeerListChanged() { - d.notifier.peerListChanged(d.numOfPeers()) -} - -func (d *Status) notifyAddressChanged() { - d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) -} - func (d *Status) numOfPeers() int { return len(d.peers) + len(d.offlinePeers) } @@ -1239,6 +1322,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus { } pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP + pbFullStatus.LocalPeerState.Ipv6 = fs.LocalPeerState.IPv6 pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN @@ -1254,6 +1338,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus { pbPeerState := &proto.PeerState{ IP: peerState.IP, + Ipv6: peerState.IPv6, PubKey: peerState.PubKey, ConnStatus: peerState.ConnStatus.String(), ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 272638750..9bafca55a 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -8,19 +8,20 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAddPeer(t *testing.T) { key := "abc" ip := "100.108.254.1" status := NewRecorder("https://mgm") - err := status.AddPeer(key, "abc.netbird", ip) + err := status.AddPeer(key, "abc.netbird", ip, "") assert.NoError(t, err, "shouldn't return error") 
_, exists := status.peers[key] assert.True(t, exists, "value was found") - err = status.AddPeer(key, "abc.netbird", ip) + err = status.AddPeer(key, "abc.netbird", ip, "") assert.Error(t, err, "should return error on duplicate") } @@ -29,7 +30,7 @@ func TestGetPeer(t *testing.T) { key := "abc" ip := "100.108.254.1" status := NewRecorder("https://mgm") - err := status.AddPeer(key, "abc.netbird", ip) + err := status.AddPeer(key, "abc.netbird", ip, "") assert.NoError(t, err, "shouldn't return error") peerStatus, err := status.GetPeer(key) @@ -46,7 +47,7 @@ func TestUpdatePeerState(t *testing.T) { ip := "10.10.10.10" fqdn := "peer-a.netbird.local" status := NewRecorder("https://mgm") - _ = status.AddPeer(key, fqdn, ip) + require.NoError(t, status.AddPeer(key, fqdn, ip, "")) peerState := State{ PubKey: key, @@ -85,7 +86,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { key := "abc" ip := "10.10.10.10" status := NewRecorder("https://mgm") - _ = status.AddPeer(key, "abc.netbird", ip) + _ = status.AddPeer(key, "abc.netbird", ip, "") sub := status.SubscribeToPeerStateChanges(context.Background(), key) assert.NotNil(t, sub, "channel shouldn't be nil") diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 06309fbaf..0402992c9 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "sync" "sync/atomic" @@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.relaySupportedOnRemotePeer.Store(true) // the relayManager will return with error in case if the connection has lost with relay server - currentRelayAddress, err := w.relayManager.RelayInstanceAddress() + currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress() if err != nil { w.log.Errorf("failed to handle new offer: %s", err) return } srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) + var 
serverIP netip.Addr + if srv == remoteOfferAnswer.RelaySrvAddress { + serverIP = remoteOfferAnswer.RelaySrvIP + } - relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key) + relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { w.log.Debugf("handled offer by reusing existing relay connection") @@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { }) } -func (w *WorkerRelay) RelayInstanceAddress() (string, error) { +func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) { return w.relayManager.RelayInstanceAddress() } diff --git a/client/internal/portforward/pcp/nat.go b/client/internal/portforward/pcp/nat.go index 1dc24274b..6491e7367 100644 --- a/client/internal/portforward/pcp/nat.go +++ b/client/internal/portforward/pcp/nat.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "runtime" "sync" "time" @@ -177,7 +178,12 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) { return nil, nil, err } - _, gateway, localIP, err = router.Route(net.IPv4zero) + dst := net.IPv4zero + if runtime.GOOS == "linux" { + // go-netroute v0.4.0 rejects unspecified destinations client-side on Linux. 
+ dst = net.IPv4(0, 0, 0, 1) + } + _, gateway, localIP, err = router.Route(dst) if err != nil { return nil, nil, err } @@ -196,7 +202,12 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) { return nil, nil, err } - _, gateway, localIP, err = router.Route(net.IPv6zero) + dst := net.IPv6zero + if runtime.GOOS == "linux" { + // ::2 + dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} + } + _, gateway, localIP, err = router.Route(dst) if err != nil { return nil, nil, err } diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 20c615d57..cd5bc0680 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net" "net/url" "os" "os/user" @@ -89,6 +90,7 @@ type ConfigInput struct { DisableFirewall *bool BlockLANAccess *bool BlockInbound *bool + DisableIPv6 *bool DisableNotifications *bool @@ -127,6 +129,7 @@ type Config struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool DisableNotifications *bool @@ -542,6 +545,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } + if input.DisableIPv6 != nil && *input.DisableIPv6 != config.DisableIPv6 { + log.Infof("setting IPv6 overlay disabled=%v", *input.DisableIPv6) + config.DisableIPv6 = *input.DisableIPv6 + updated = true + } + if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if *input.DisableNotifications { log.Infof("disabling notifications") @@ -751,8 +760,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri return config, nil } - newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d", - config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443)) + newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s", config.ManagementURL.Scheme, 
net.JoinHostPort(defaultManagementURL.Hostname(), "443"))) if err != nil { return nil, err } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 59be5b0a7..f00a8d93a 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "strconv" "sync" "time" @@ -257,7 +258,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri } }() - turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port) + turnServerAddr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port)) var conn net.PacketConn switch uri.Proto { diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go index 1faa22dc5..11cda8dbc 100644 --- a/client/internal/rosenpass/manager.go +++ b/client/internal/rosenpass/manager.go @@ -75,7 +75,7 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar if err != nil { return fmt.Errorf("failed to parse rosenpass address: %w", err) } - peerAddr := fmt.Sprintf("%s:%s", wireGuardIP, strPort) + peerAddr := net.JoinHostPort(wireGuardIP, strPort) if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil { return fmt.Errorf("failed to resolve peer endpoint address: %w", err) } @@ -259,6 +259,9 @@ func findRandomAvailableUDPPort() (int, error) { } defer conn.Close() - splitAddress := strings.Split(conn.LocalAddr().String(), ":") - return strconv.Atoi(splitAddress[len(splitAddress)-1]) + _, portStr, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return 0, fmt.Errorf("parse local address %s: %w", conn.LocalAddr(), err) + } + return strconv.Atoi(portStr) } diff --git a/client/internal/rosenpass/manager_test.go b/client/internal/rosenpass/manager_test.go new file mode 100644 index 000000000..90bbdda59 --- /dev/null +++ b/client/internal/rosenpass/manager_test.go @@ -0,0 +1,14 @@ +package rosenpass + +import ( + "testing" + + "github.com/stretchr/testify/require" +) 
+ +func TestFindRandomAvailableUDPPort(t *testing.T) { + port, err := findRandomAvailableUDPPort() + require.NoError(t, err) + require.Greater(t, port, 0) + require.LessOrEqual(t, port, 65535) +} diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index e6ef8b876..c691c54f8 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -3,9 +3,8 @@ package client import ( "context" "fmt" - "net" + "net/netip" "reflect" - "strconv" "time" log "github.com/sirupsen/logrus" @@ -566,7 +565,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler { return dnsinterceptor.New(params) case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(params.WgInterface) - dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) + dnsAddr := netip.AddrPortFrom(dns.RuntimeIP(), uint16(dns.RuntimePort())) return dynamic.NewRoute(params, dnsAddr) default: return static.NewRoute(params) diff --git a/client/internal/routemanager/client/client_bench_test.go b/client/internal/routemanager/client/client_bench_test.go index 591042ac5..668aec427 100644 --- a/client/internal/routemanager/client/client_bench_test.go +++ b/client/internal/routemanager/client/client_bench_test.go @@ -46,7 +46,7 @@ func generateBenchmarkData(tier benchmarkTier) (*peer.Status, map[route.ID]*rout fqdn := fmt.Sprintf("peer-%d.example.com", i) ip := fmt.Sprintf("10.0.%d.%d", i/256, i%256) - err := statusRecorder.AddPeer(peerKey, fqdn, ip) + err := statusRecorder.AddPeer(peerKey, fqdn, ip, "") if err != nil { panic(fmt.Sprintf("failed to add peer: %v", err)) } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 64f2a8789..e25cc2a5c 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -582,7 +582,7 @@ func (d 
*DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWri if nsNet != nil { reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream) } else { - client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) + client, clientErr := nbdns.GetClientPrivate(d.wgInterface, upstreamIP, dnsTimeout) if clientErr != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr)) return nil diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 8d1398a7a..f0efd7b22 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -50,10 +50,10 @@ type Route struct { cancel context.CancelFunc statusRecorder *peer.Status wgInterface iface.WGIface - resolverAddr string + resolverAddr netip.AddrPort } -func NewRoute(params common.HandlerParams, resolverAddr string) *Route { +func NewRoute(params common.HandlerParams, resolverAddr netip.AddrPort) *Route { return &Route{ route: params.Route, routeRefCounter: params.RouteRefCounter, diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go index 8fed1c8f9..1ae281d56 100644 --- a/client/internal/routemanager/dynamic/route_ios.go +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -17,37 +17,47 @@ import ( const dialTimeout = 10 * time.Second func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { - privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout) + privateClient, err := nbdns.GetClientPrivate(r.wgInterface, r.resolverAddr.Addr(), dialTimeout) if err != nil { return nil, fmt.Errorf("error while creating private client: %s", err) } - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA) - + fqdn := dns.Fqdn(domain.PunycodeString()) startTime := time.Now() - response, _, err 
:= nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr) - if err != nil { - return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) - } + var ips []net.IP + var queryErr error - if response.Rcode != dns.RcodeSuccess { - return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode]) - } + for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} { + msg := new(dns.Msg) + msg.SetQuestion(fqdn, qtype) - ips := make([]net.IP, 0) - - for _, answ := range response.Answer { - if aRecord, ok := answ.(*dns.A); ok { - ips = append(ips, aRecord.A) + response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String()) + if err != nil { + if queryErr == nil { + queryErr = fmt.Errorf("DNS query for %s (type %d) after %s: %w", domain.SafeString(), qtype, time.Since(startTime), err) + } + continue } - if aaaaRecord, ok := answ.(*dns.AAAA); ok { - ips = append(ips, aaaaRecord.AAAA) + + if response.Rcode != dns.RcodeSuccess { + continue + } + + for _, answ := range response.Answer { + if aRecord, ok := answ.(*dns.A); ok { + ips = append(ips, aRecord.A) + } + if aaaaRecord, ok := answ.(*dns.AAAA); ok { + ips = append(ips, aaaaRecord.AAAA) + } } } if len(ips) == 0 { + if queryErr != nil { + return nil, queryErr + } return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString()) } diff --git a/client/internal/routemanager/fakeip/fakeip.go b/client/internal/routemanager/fakeip/fakeip.go index 1592045d2..5be4ca12e 100644 --- a/client/internal/routemanager/fakeip/fakeip.go +++ b/client/internal/routemanager/fakeip/fakeip.go @@ -1,93 +1,145 @@ package fakeip import ( + "errors" "fmt" "net/netip" "sync" ) -// Manager manages allocation of fake IPs from the 240.0.0.0/8 block -type Manager struct { - mu sync.Mutex - nextIP netip.Addr // Next IP to allocate +var ( + // 240.0.0.1 - 240.255.255.254, block 240.0.0.0/8 (reserved, RFC 1112) + v4Base = 
netip.AddrFrom4([4]byte{240, 0, 0, 1}) + v4Max = netip.AddrFrom4([4]byte{240, 255, 255, 254}) + v4Block = netip.PrefixFrom(netip.AddrFrom4([4]byte{240, 0, 0, 0}), 8) + + // 0100::1 - 0100::ffff:ffff:ffff:fffe, block 0100::/64 (discard, RFC 6666) + v6Base = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) + v6Max = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}) + v6Block = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x01, 0x00}), 64) +) + +// fakeIPPool holds the allocation state for a single address family. +type fakeIPPool struct { + nextIP netip.Addr + baseIP netip.Addr + maxIP netip.Addr + block netip.Prefix allocated map[netip.Addr]netip.Addr // real IP -> fake IP fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP - baseIP netip.Addr // First usable IP: 240.0.0.1 - maxIP netip.Addr // Last usable IP: 240.255.255.254 } -// NewManager creates a new fake IP manager using 240.0.0.0/8 block -func NewManager() *Manager { - baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1}) - maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254}) - - return &Manager{ - nextIP: baseIP, +func newPool(base, maxAddr netip.Addr, block netip.Prefix) *fakeIPPool { + return &fakeIPPool{ + nextIP: base, + baseIP: base, + maxIP: maxAddr, + block: block, allocated: make(map[netip.Addr]netip.Addr), fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: baseIP, - maxIP: maxIP, } } -// AllocateFakeIP allocates a fake IP for the given real IP -// Returns the fake IP, or existing fake IP if already allocated -func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { - if !realIP.Is4() { - return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported") - } - - m.mu.Lock() - defer m.mu.Unlock() - - if fakeIP, exists := m.allocated[realIP]; exists { +// allocate allocates a fake IP for the given real IP. +// Returns the existing fake IP if already allocated. 
+func (p *fakeIPPool) allocate(realIP netip.Addr) (netip.Addr, error) { + if fakeIP, exists := p.allocated[realIP]; exists { return fakeIP, nil } - startIP := m.nextIP + startIP := p.nextIP for { - currentIP := m.nextIP + currentIP := p.nextIP // Advance to next IP, wrapping at boundary - if m.nextIP.Compare(m.maxIP) >= 0 { - m.nextIP = m.baseIP + if p.nextIP.Compare(p.maxIP) >= 0 { + p.nextIP = p.baseIP } else { - m.nextIP = m.nextIP.Next() + p.nextIP = p.nextIP.Next() } - // Check if current IP is available - if _, inUse := m.fakeToReal[currentIP]; !inUse { - m.allocated[realIP] = currentIP - m.fakeToReal[currentIP] = realIP + if _, inUse := p.fakeToReal[currentIP]; !inUse { + p.allocated[realIP] = currentIP + p.fakeToReal[currentIP] = realIP return currentIP, nil } - // Prevent infinite loop if all IPs exhausted - if m.nextIP.Compare(startIP) == 0 { - return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block") + if p.nextIP.Compare(startIP) == 0 { + return netip.Addr{}, fmt.Errorf("no more fake IPs available in %s block", p.block) } } } -// GetFakeIP returns the fake IP for a real IP if it exists +// Manager manages allocation of fake IPs for dynamic DNS routes. +// IPv4 uses 240.0.0.0/8 (reserved), IPv6 uses 0100::/64 (discard, RFC 6666). +type Manager struct { + mu sync.Mutex + v4 *fakeIPPool + v6 *fakeIPPool +} + +// NewManager creates a new fake IP manager. +func NewManager() *Manager { + return &Manager{ + v4: newPool(v4Base, v4Max, v4Block), + v6: newPool(v6Base, v6Max, v6Block), + } +} + +func (m *Manager) pool(ip netip.Addr) *fakeIPPool { + if ip.Is6() { + return m.v6 + } + return m.v4 +} + +// AllocateFakeIP allocates a fake IP for the given real IP. 
+func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, errors.New("invalid IP address") + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.pool(realIP).allocate(realIP) +} + +// GetFakeIP returns the fake IP for a real IP if it exists. func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - fakeIP, exists := m.allocated[realIP] - return fakeIP, exists + fakeIP, ok := m.pool(realIP).allocated[realIP] + return fakeIP, ok } -// GetRealIP returns the real IP for a fake IP if it exists, otherwise false +// GetRealIP returns the real IP for a fake IP if it exists. func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) { + fakeIP = fakeIP.Unmap() + if !fakeIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - realIP, exists := m.fakeToReal[fakeIP] - return realIP, exists + realIP, ok := m.pool(fakeIP).fakeToReal[fakeIP] + return realIP, ok } -// GetFakeIPBlock returns the fake IP block used by this manager +// GetFakeIPBlock returns the v4 fake IP block used by this manager. func (m *Manager) GetFakeIPBlock() netip.Prefix { - return netip.MustParsePrefix("240.0.0.0/8") + return m.v4.block +} + +// GetFakeIPv6Block returns the v6 fake IP block used by this manager. 
+func (m *Manager) GetFakeIPv6Block() netip.Prefix { + return m.v6.block } diff --git a/client/internal/routemanager/fakeip/fakeip_test.go b/client/internal/routemanager/fakeip/fakeip_test.go index ad3e4bd4e..f554f970d 100644 --- a/client/internal/routemanager/fakeip/fakeip_test.go +++ b/client/internal/routemanager/fakeip/fakeip_test.go @@ -9,16 +9,16 @@ import ( func TestNewManager(t *testing.T) { manager := NewManager() - if manager.baseIP.String() != "240.0.0.1" { - t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String()) + if manager.v4.baseIP.String() != "240.0.0.1" { + t.Errorf("Expected v4 base IP 240.0.0.1, got %s", manager.v4.baseIP.String()) } - if manager.maxIP.String() != "240.255.255.254" { - t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String()) + if manager.v4.maxIP.String() != "240.255.255.254" { + t.Errorf("Expected v4 max IP 240.255.255.254, got %s", manager.v4.maxIP.String()) } - if manager.nextIP.Compare(manager.baseIP) != 0 { - t.Errorf("Expected nextIP to start at baseIP") + if manager.v6.baseIP.String() != "100::1" { + t.Errorf("Expected v6 base IP 100::1, got %s", manager.v6.baseIP.String()) } } @@ -35,7 +35,6 @@ func TestAllocateFakeIP(t *testing.T) { t.Error("Fake IP should be IPv4") } - // Check it's in the correct range if fakeIP.As4()[0] != 240 { t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String()) } @@ -51,13 +50,31 @@ func TestAllocateFakeIP(t *testing.T) { } } -func TestAllocateFakeIPIPv6Rejection(t *testing.T) { +func TestAllocateFakeIPv6(t *testing.T) { manager := NewManager() - realIPv6 := netip.MustParseAddr("2001:db8::1") + realIP := netip.MustParseAddr("2001:db8::1") - _, err := manager.AllocateFakeIP(realIPv6) - if err == nil { - t.Error("Expected error for IPv6 address") + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + + if !fakeIP.Is6() { + t.Error("Fake IP should be IPv6") + } + + if 
!netip.MustParsePrefix("100::/64").Contains(fakeIP) { + t.Errorf("Fake IP should be in 100::/64 range, got %s", fakeIP.String()) + } + + // Should return same fake IP for same real IP + fakeIP2, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to get existing fake IPv6: %v", err) + } + + if fakeIP.Compare(fakeIP2) != 0 { + t.Errorf("Expected same fake IP, got %s and %s", fakeIP.String(), fakeIP2.String()) } } @@ -65,13 +82,11 @@ func TestGetFakeIP(t *testing.T) { manager := NewManager() realIP := netip.MustParseAddr("1.1.1.1") - // Should not exist initially _, exists := manager.GetFakeIP(realIP) if exists { t.Error("Fake IP should not exist before allocation") } - // Allocate and check expectedFakeIP, err := manager.AllocateFakeIP(realIP) if err != nil { t.Fatalf("Failed to allocate: %v", err) @@ -87,12 +102,30 @@ func TestGetFakeIP(t *testing.T) { } } +func TestGetRealIPv6(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("2001:db8::1") + + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate: %v", err) + } + + gotReal, exists := manager.GetRealIP(fakeIP) + if !exists { + t.Error("Real IP should exist for allocated fake IP") + } + + if gotReal.Compare(realIP) != 0 { + t.Errorf("Expected real IP %s, got %s", realIP, gotReal) + } +} + func TestMultipleAllocations(t *testing.T) { manager := NewManager() allocations := make(map[netip.Addr]netip.Addr) - // Allocate multiple IPs for i := 1; i <= 100; i++ { realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) fakeIP, err := manager.AllocateFakeIP(realIP) @@ -100,7 +133,6 @@ func TestMultipleAllocations(t *testing.T) { t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err) } - // Check for duplicates for _, existingFake := range allocations { if fakeIP.Compare(existingFake) == 0 { t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String()) @@ -110,7 +142,6 @@ func TestMultipleAllocations(t 
*testing.T) { allocations[realIP] = fakeIP } - // Verify all allocations can be retrieved for realIP, expectedFake := range allocations { actualFake, exists := manager.GetFakeIP(realIP) if !exists { @@ -124,11 +155,13 @@ func TestMultipleAllocations(t *testing.T) { func TestGetFakeIPBlock(t *testing.T) { manager := NewManager() - block := manager.GetFakeIPBlock() - expected := "240.0.0.0/8" - if block.String() != expected { - t.Errorf("Expected %s, got %s", expected, block.String()) + if block := manager.GetFakeIPBlock(); block.String() != "240.0.0.0/8" { + t.Errorf("Expected 240.0.0.0/8, got %s", block.String()) + } + + if block := manager.GetFakeIPv6Block(); block.String() != "100::/64" { + t.Errorf("Expected 100::/64, got %s", block.String()) } } @@ -141,7 +174,6 @@ func TestConcurrentAccess(t *testing.T) { var wg sync.WaitGroup results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine) - // Concurrent allocations for i := 0; i < numGoroutines; i++ { wg.Add(1) go func(goroutineID int) { @@ -161,7 +193,6 @@ func TestConcurrentAccess(t *testing.T) { wg.Wait() close(results) - // Check for duplicates seen := make(map[netip.Addr]bool) count := 0 for fakeIP := range results { @@ -178,47 +209,61 @@ func TestConcurrentAccess(t *testing.T) { } func TestIPExhaustion(t *testing.T) { - // Create a manager with limited range for testing manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 3}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::3"), + netip.MustParsePrefix("100::/64"), + ), } - // Allocate all available IPs - realIPs := []netip.Addr{ - 
netip.MustParseAddr("1.0.0.1"), - netip.MustParseAddr("1.0.0.2"), - netip.MustParseAddr("1.0.0.3"), - } - - for _, realIP := range realIPs { - _, err := manager.AllocateFakeIP(realIP) + for _, realIP := range []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) if err != nil { t.Fatalf("Failed to allocate fake IP: %v", err) } } - // Try to allocate one more - should fail _, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4")) if err == nil { - t.Error("Expected exhaustion error") + t.Error("Expected v4 exhaustion error") + } + + // Same for v6 + for _, realIP := range []string{"2001:db8::1", "2001:db8::2", "2001:db8::3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + } + + _, err = manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::4")) + if err == nil { + t.Error("Expected v6 exhaustion error") } } func TestWrapAround(t *testing.T) { - // Create manager starting near the end of range manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 254}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::ffff:ffff:ffff:fffe"), + netip.MustParsePrefix("100::/64"), + ), } + // Start near the end + manager.v4.nextIP = netip.AddrFrom4([4]byte{240, 0, 0, 254}) - // Allocate the last IP fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1")) if err != nil { t.Fatalf("Failed to allocate first IP: %v", err) @@ -228,7 +273,6 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String()) } - // Next 
allocation should wrap around to the beginning fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2")) if err != nil { t.Fatalf("Failed to allocate second IP: %v", err) @@ -238,3 +282,32 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String()) } } + +func TestMixedV4V6(t *testing.T) { + manager := NewManager() + + v4Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("8.8.8.8")) + if err != nil { + t.Fatalf("Failed to allocate v4: %v", err) + } + + v6Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::1")) + if err != nil { + t.Fatalf("Failed to allocate v6: %v", err) + } + + if !v4Fake.Is4() || !v6Fake.Is6() { + t.Errorf("Wrong families: v4=%s v6=%s", v4Fake, v6Fake) + } + + // Reverse lookups should work for both + gotV4, ok := manager.GetRealIP(v4Fake) + if !ok || gotV4.String() != "8.8.8.8" { + t.Errorf("v4 reverse lookup failed: got %s, ok=%v", gotV4, ok) + } + + gotV6, ok := manager.GetRealIP(v6Fake) + if !ok || gotV6.String() != "2001:db8::1" { + t.Errorf("v6 reverse lookup failed: got %s, ok=%v", gotV6, ok) + } +} diff --git a/client/internal/routemanager/ipfwdstate/ipfwdstate.go b/client/internal/routemanager/ipfwdstate/ipfwdstate.go index da81c18f9..2be1c2ae7 100644 --- a/client/internal/routemanager/ipfwdstate/ipfwdstate.go +++ b/client/internal/routemanager/ipfwdstate/ipfwdstate.go @@ -9,7 +9,11 @@ import ( ) // IPForwardingState is a struct that keeps track of the IP forwarding state. -// todo: read initial state of the IP forwarding from the system and reset the state based on it +// todo: read initial state of the IP forwarding from the system and reset the state based on it. +// todo: separate v4/v6 forwarding state, since the sysctls are independent +// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). 
Currently the nftables +// manager shares one instance between both routers, which works only because +// EnableIPForwarding enables both sysctls in a single call. type IPForwardingState struct { enabledCounter int } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 3923e153b..e5d9363ca 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -159,16 +159,24 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { if config.DNSFeatureFlag { m.fakeIPManager = fakeip.NewManager() - id := uuid.NewString() + v4ID := uuid.NewString() fakeIPRoute := &route.Route{ - ID: route.ID(id), + ID: route.ID(v4ID), Network: m.fakeIPManager.GetFakeIPBlock(), - NetID: route.NetID(id), + NetID: route.NetID(v4ID), Peer: m.pubKey, NetworkType: route.IPv4Network, } - cr = append(cr, fakeIPRoute) - m.notifier.SetFakeIPRoute(fakeIPRoute) + v6ID := uuid.NewString() + fakeIPv6Route := &route.Route{ + ID: route.ID(v6ID), + Network: m.fakeIPManager.GetFakeIPv6Block(), + NetID: route.NetID(v6ID), + Peer: m.pubKey, + NetworkType: route.IPv6Network, + } + cr = append(cr, fakeIPRoute, fakeIPv6Route) + m.notifier.SetFakeIPRoutes([]*route.Route{fakeIPRoute, fakeIPv6Route}) } m.notifier.SetInitialClientRoutes(cr, routesForComparison) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 3697545ae..926f06bc9 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/route" ) @@ -409,7 +410,7 @@ func TestManagerUpdateRoutes(t *testing.T) { } opts := iface.WGIFaceOpts{ IFaceName: fmt.Sprintf("utun43%d", n), - Address: "100.65.65.2/24", + Address: 
wgaddr.MustParseWGAddress("100.65.65.2/24"), WGPort: 33100, WGPrivKey: peerPrivateKey.String(), MTU: iface.DefaultMTU, diff --git a/client/internal/routemanager/notifier/notifier_android.go b/client/internal/routemanager/notifier/notifier_android.go index 55e0b7421..140a583f7 100644 --- a/client/internal/routemanager/notifier/notifier_android.go +++ b/client/internal/routemanager/notifier/notifier_android.go @@ -16,7 +16,7 @@ import ( type Notifier struct { initialRoutes []*route.Route currentRoutes []*route.Route - fakeIPRoute *route.Route + fakeIPRoutes []*route.Route listener listener.NetworkChangeListener listenerMux sync.Mutex @@ -38,9 +38,9 @@ func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesFo n.currentRoutes = filterStatic(routesForComparison) } -// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild. -func (n *Notifier) SetFakeIPRoute(r *route.Route) { - n.fakeIPRoute = r +// SetFakeIPRoutes stores the fake IP routes to be included in every TUN rebuild. +func (n *Notifier) SetFakeIPRoutes(routes []*route.Route) { + n.fakeIPRoutes = routes } func (n *Notifier) OnNewRoutes(idMap route.HAMap) { @@ -74,14 +74,12 @@ func (n *Notifier) notify() { } allRoutes := slices.Clone(n.currentRoutes) - if n.fakeIPRoute != nil { - allRoutes = append(allRoutes, n.fakeIPRoute) - } + allRoutes = append(allRoutes, n.fakeIPRoutes...) 
routeStrings := n.routesToStrings(allRoutes) sort.Strings(routeStrings) go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ",")) + l.OnNetworkChanged(strings.Join(routeStrings, ",")) }(n.listener) } @@ -119,14 +117,5 @@ func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool { func (n *Notifier) GetInitialRouteRanges() []string { initialStrings := n.routesToStrings(n.initialRoutes) sort.Strings(initialStrings) - return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes) -} - -func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string { - for _, r := range routes { - if r.Network.Addr().Is4() && r.Network.Bits() == 0 { - return append(slices.Clone(inputRanges), "::/0") - } - } - return inputRanges + return initialStrings } diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go index 68c85067a..27a2a722d 100644 --- a/client/internal/routemanager/notifier/notifier_ios.go +++ b/client/internal/routemanager/notifier/notifier_ios.go @@ -34,7 +34,7 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // iOS doesn't care about initial routes } -func (n *Notifier) SetFakeIPRoute(*route.Route) { +func (n *Notifier) SetFakeIPRoutes([]*route.Route) { // Not used on iOS } @@ -65,19 +65,10 @@ func (n *Notifier) notify() { } go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ",")) + l.OnNetworkChanged(strings.Join(n.currentPrefixes, ",")) }(n.listener) } func (n *Notifier) GetInitialRouteRanges() []string { return nil } - -func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string { - for _, r := range inputRanges { - if r == "0.0.0.0/0" { - return append(slices.Clone(inputRanges), "::/0") - } - } - return inputRanges -} diff --git 
a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go index 97c815cf0..f57cadb0b 100644 --- a/client/internal/routemanager/notifier/notifier_other.go +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -23,7 +23,7 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // Not used on non-mobile platforms } -func (n *Notifier) SetFakeIPRoute(*route.Route) { +func (n *Notifier) SetFakeIPRoutes([]*route.Route) { // Not used on non-mobile platforms } diff --git a/client/internal/routemanager/server/server.go b/client/internal/routemanager/server/server.go index e674c80cd..f569c0cac 100644 --- a/client/internal/routemanager/server/server.go +++ b/client/internal/routemanager/server/server.go @@ -21,6 +21,7 @@ type Router struct { firewall firewall.Manager wgInterface iface.WGIface statusRecorder *peer.Status + useNewDNSRoute bool } func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) { @@ -37,6 +38,8 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout r.mux.Lock() defer r.mux.Unlock() + prevUseNewDNSRoute := r.useNewDNSRoute + serverRoutesToRemove := make([]route.ID, 0) for routeID := range r.routes { @@ -48,7 +51,7 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout for _, routeID := range serverRoutesToRemove { oldRoute := r.routes[routeID] - err := r.removeFromServerNetwork(oldRoute) + err := r.removeFromServerNetwork(oldRoute, prevUseNewDNSRoute) if err != nil { log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) @@ -56,6 +59,8 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout delete(r.routes, routeID) } + r.useNewDNSRoute = useNewDNSRoute + // If routing is to be disabled, do it after routes have been removed // If routing is to 
be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled if len(routesMap) > 0 { @@ -85,13 +90,13 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout return nil } -func (r *Router) removeFromServerNetwork(route *route.Route) error { +func (r *Router) removeFromServerNetwork(route *route.Route, useNewDNSRoute bool) error { if r.ctx.Err() != nil { log.Infof("Not removing from server network because context is done") return r.ctx.Err() } - routerPair := routeToRouterPair(route, false) + routerPair := routeToRouterPair(route, useNewDNSRoute) if err := r.firewall.RemoveNatRule(routerPair); err != nil { return fmt.Errorf("remove routing rules: %w", err) } @@ -124,7 +129,7 @@ func (r *Router) CleanUp() { defer r.mux.Unlock() for _, route := range r.routes { - routerPair := routeToRouterPair(route, false) + routerPair := routeToRouterPair(route, r.useNewDNSRoute) if err := r.firewall.RemoveNatRule(routerPair); err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } @@ -146,8 +151,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP if useNewDNSRoute { destination.Set = firewall.NewDomainSet(route.Domains) } else { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination.Prefix) + destination = getDefaultPrefix(route.Network) } } else { destination.Prefix = route.Network.Masked() @@ -158,6 +162,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP Source: source, Destination: destination, Masquerade: route.Masquerade, + Dynamic: route.IsDynamic(), } } diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index c0ca21d22..165448b60 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -107,8 +107,16 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error { 
addr.IsInterfaceLocalMulticast(), addr.IsMulticast(), addr.IsUnspecified() && prefix.Bits() != 0, - r.wgInterface.Address().Network.Contains(addr): + r.isOwnAddress(addr): return vars.ErrRouteNotAllowed } return nil } + +func (r *SysOps) isOwnAddress(addr netip.Addr) bool { + if r.wgInterface == nil { + return false + } + wgAddr := r.wgInterface.Address() + return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr)) +} diff --git a/client/internal/routemanager/systemops/systemops_darwin.go b/client/internal/routemanager/systemops/systemops_darwin.go index d6875ff95..3fcac4c6a 100644 --- a/client/internal/routemanager/systemops/systemops_darwin.go +++ b/client/internal/routemanager/systemops/systemops_darwin.go @@ -89,8 +89,16 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) { return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec) } + reused := false if err := r.addScopedDefault(unspec, nexthop); err != nil { - return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err) + if !errors.Is(err, unix.EEXIST) { + return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err) + } + // macOS installs its own RTF_IFSCOPE defaults for primary service + // selection on multi-NIC setups, so a route on this ifindex can + // already exist before we try. Binding to it via IP[V6]_BOUND_IF + // still produces the scoped lookup we need. 
+ reused = true } af := unix.AF_INET @@ -102,7 +110,11 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) { if nexthop.IP.IsValid() { via = nexthop.IP.String() } - log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec)) + verb := "installed" + if reused { + verb = "reused existing" + } + log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec)) return true, nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4211eb057..2b96c14dc 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -221,30 +221,20 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er return err } - // TODO: remove once IPv6 is supported on the interface - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) + // When the interface has no v6, add v6 split-default as blackhole so + // unroutable v6 goes to WG (dropped, no AllowedIPs) instead of leaking + // to the system default route. When v6 is active, management sends ::/0 + // as a separate route that the dedicated handler adds. + // Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure. 
+ if !r.wgInterface.Address().HasIPv6() { + if err := r.addV6SplitDefault(nextHop); err != nil { + log.Warnf("failed to add v6 split-default blackhole: %s", err) } - return fmt.Errorf("add unreachable route split 2: %w", err) } return nil case vars.Defaultv6: - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil + return r.addV6SplitDefault(nextHop) } return r.addToRouteTable(prefix, nextHop) @@ -265,30 +255,42 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) result = multierror.Append(result, err) } - // TODO: remove once IPv6 is supported on the interface - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) + if !r.wgInterface.Address().HasIPv6() { + result = multierror.Append(result, r.removeV6SplitDefault(nextHop)) } return nberrors.FormatErrorOrNil(result) case vars.Defaultv6: - var result *multierror.Error - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) - } - - return nberrors.FormatErrorOrNil(result) + return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop)) default: return r.removeFromRouteTable(prefix, nextHop) } } +func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error { + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + return fmt.Errorf("add 
split 1: %w", err) + } + if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback v6 split-default: %s", err2) + } + return fmt.Errorf("add split 2: %w", err) + } + return nil +} + +func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error { + var result *multierror.Error + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + return result +} + func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { @@ -342,6 +344,22 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) { if err != nil { return Nexthop{}, fmt.Errorf("new netroute: %w", err) } + + // go-netroute v0.4.0 rejects unspecified destinations on Linux with a hard + // client-side check. Substitute the lowest non-loopback address so the + // lookup falls through to the default route (::1 / 127.0.0.1 would match + // loopback, ::/0.0.0.0 are unspec). BSD/Windows pass the query straight to + // the kernel and need no substitution. 
+ if runtime.GOOS == "linux" && ip.IsUnspecified() { + if ip.Is6() { + // ::2 + ip = netip.AddrFrom16([16]byte{15: 2}) + } else { + // 0.0.0.1 + ip = netip.AddrFrom4([4]byte{0, 0, 0, 1}) + } + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) if err != nil { log.Debugf("Failed to get route for %s: %v", ip, err) diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 01916fbe3..5695c40c3 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -21,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/routemanager/vars" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -354,9 +355,13 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { require.NoError(t, err, "Should be able to get IPv4 default route") t.Logf("Initial IPv4 next hop: %s", initialNextHopV4) + if testCase.prefix.Addr().Is6() && !testCase.expectError { + ensureIPv6DefaultRoute(t) + } + initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified()) if testCase.prefix.Addr().Is6() && - (errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) { + initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun") { t.Skip("Skipping test as no ipv6 default route is available") } if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { @@ -441,7 +446,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen opts := iface.WGIFaceOpts{ IFaceName: interfaceName, - Address: ipAddressCIDR, + Address: wgaddr.MustParseWGAddress(ipAddressCIDR), WGPrivKey: peerPrivateKey.String(), WGPort: listenPort, MTU: iface.DefaultMTU, diff --git 
a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 39a9fd978..8c6b7d9a9 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -53,6 +53,8 @@ const ( // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. ipv4ForwardingPath = "net.ipv4.ip_forward" + // ipv6ForwardingPath is the path to the file containing the IPv6 forwarding setting. + ipv6ForwardingPath = "net.ipv6.conf.all.forwarding" ) var ErrTableIDExists = errors.New("ID exists with different name") @@ -185,10 +187,11 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + // When the peer has no IPv6, blackhole v6 to prevent leaking. + // When IPv6 is enabled, management sends ::/0 as a separate route. 
+ if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add blackhole: %w", err) + return fmt.Errorf("add v6 blackhole: %w", err) } } if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -206,10 +209,9 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error return r.genericRemoveVPNRoute(prefix, intf) } - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove unreachable route: %w", err) + log.Debugf("remove v6 blackhole: %v", err) } } if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -762,8 +764,13 @@ func flushRoutes(tableID, family int) error { } func EnableIPForwarding() error { - _, err := sysctl.Set(ipv4ForwardingPath, 1, false) - return err + if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil { + return err + } + if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil { + log.Warnf("failed to enable IPv6 forwarding: %v", err) + } + return nil } // entryExists checks if the specified ID or name already exists in the rt_tables file diff --git a/client/internal/routemanager/systemops/v6route_bsd_test.go b/client/internal/routemanager/systemops/v6route_bsd_test.go new file mode 100644 index 000000000..98ce29c6d --- /dev/null +++ b/client/internal/routemanager/systemops/v6route_bsd_test.go @@ -0,0 +1,30 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package systemops + +import ( + "bytes" + "os/exec" + "testing" +) + +// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback +// interface so route lookups for global IPv6 prefixes 
resolve in environments +// without v6 connectivity. If a default already exists it is left alone. +func ensureIPv6DefaultRoute(t *testing.T) { + t.Helper() + + out, err := exec.Command("route", "-6", "add", "default", "-iface", "lo0").CombinedOutput() + if err != nil { + // Existing default; nothing to install or clean up. + if bytes.Contains(out, []byte("route already in table")) { + return + } + t.Skipf("install IPv6 fallback default route: %v: %s", err, out) + } + t.Cleanup(func() { + if out, err := exec.Command("route", "-6", "delete", "default").CombinedOutput(); err != nil { + t.Logf("delete IPv6 fallback default route: %v: %s", err, out) + } + }) +} diff --git a/client/internal/routemanager/systemops/v6route_linux_test.go b/client/internal/routemanager/systemops/v6route_linux_test.go new file mode 100644 index 000000000..0b17cefff --- /dev/null +++ b/client/internal/routemanager/systemops/v6route_linux_test.go @@ -0,0 +1,41 @@ +//go:build linux && !android + +package systemops + +import ( + "errors" + "net" + "syscall" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" +) + +// ensureIPv6DefaultRoute installs a low-preference IPv6 default route via the +// loopback interface so route lookups for global IPv6 prefixes resolve in +// environments without v6 connectivity. Any pre-existing default route wins +// because of its lower metric. 
+func ensureIPv6DefaultRoute(t *testing.T) { + t.Helper() + + lo, err := netlink.LinkByName("lo") + require.NoError(t, err, "find loopback interface") + + route := &netlink.Route{ + Dst: &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}, + LinkIndex: lo.Attrs().Index, + Priority: 1 << 20, + } + if err := netlink.RouteAdd(route); err != nil { + if errors.Is(err, syscall.EEXIST) { + return + } + t.Skipf("install IPv6 fallback default route: %v", err) + } + t.Cleanup(func() { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + t.Logf("delete IPv6 fallback default route: %v", err) + } + }) +} diff --git a/client/internal/routemanager/systemops/v6route_windows_test.go b/client/internal/routemanager/systemops/v6route_windows_test.go new file mode 100644 index 000000000..f79277b87 --- /dev/null +++ b/client/internal/routemanager/systemops/v6route_windows_test.go @@ -0,0 +1,34 @@ +//go:build windows + +package systemops + +import ( + "bytes" + "os/exec" + "testing" +) + +const loopbackIfaceWindows = "Loopback Pseudo-Interface 1" + +// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback +// interface so route lookups for global IPv6 prefixes resolve in environments +// without v6 connectivity. If a default already exists it is left alone. +func ensureIPv6DefaultRoute(t *testing.T) { + t.Helper() + + script := `New-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -RouteMetric 9999 -PolicyStore ActiveStore -ErrorAction Stop` + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + // Existing default; nothing to install or clean up. 
+ if bytes.Contains(out, []byte("already exists")) { + return + } + t.Skipf("install IPv6 fallback default route: %v: %s", err, out) + } + t.Cleanup(func() { + script := `Remove-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -Confirm:$false -ErrorAction Stop` + if out, err := exec.Command("powershell", "-Command", script).CombinedOutput(); err != nil { + t.Logf("delete IPv6 fallback default route: %v: %s", err, out) + } + }) +} diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 61c8bbc79..30afc013b 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/hashicorp/go-multierror" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" @@ -44,8 +43,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) for _, r := range allRoutes { rs.deselectedRoutes[r] = struct{}{} } @@ -78,8 +77,8 @@ func (rs *RouteSelector) SelectAllRoutes() { if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) } // DeselectRoutes removes specific routes from the selection. @@ -116,8 +115,8 @@ func (rs *RouteSelector) DeselectAllRoutes() { if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) } // IsSelected checks if a specific route is selected. 
diff --git a/client/internal/sleep/detector_darwin.go b/client/internal/sleep/detector_darwin.go index 3d6747ed1..ef495bded 100644 --- a/client/internal/sleep/detector_darwin.go +++ b/client/internal/sleep/detector_darwin.go @@ -2,217 +2,358 @@ package sleep -/* -#cgo LDFLAGS: -framework IOKit -framework CoreFoundation -#include -#include -#include - -extern void sleepCallbackBridge(); -extern void poweredOnCallbackBridge(); -extern void suspendedCallbackBridge(); -extern void resumedCallbackBridge(); - - -// C global variables for IOKit state -static IONotificationPortRef g_notifyPortRef = NULL; -static io_object_t g_notifierObject = 0; -static io_object_t g_generalInterestNotifier = 0; -static io_connect_t g_rootPort = 0; -static CFRunLoopRef g_runLoop = NULL; - -static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) { - switch (messageType) { - case kIOMessageSystemWillSleep: - sleepCallbackBridge(); - IOAllowPowerChange(g_rootPort, (long)messageArgument); - break; - case kIOMessageSystemHasPoweredOn: - poweredOnCallbackBridge(); - break; - case kIOMessageServiceIsSuspended: - suspendedCallbackBridge(); - break; - case kIOMessageServiceIsResumed: - resumedCallbackBridge(); - break; - default: - break; - } -} - -static void registerNotifications() { - g_rootPort = IORegisterForSystemPower( - NULL, - &g_notifyPortRef, - (IOServiceInterestCallback)sleepCallback, - &g_notifierObject - ); - - if (g_rootPort == 0) { - return; - } - - CFRunLoopAddSource(CFRunLoopGetCurrent(), - IONotificationPortGetRunLoopSource(g_notifyPortRef), - kCFRunLoopCommonModes); - - g_runLoop = CFRunLoopGetCurrent(); - CFRunLoopRun(); -} - -static void unregisterNotifications() { - CFRunLoopRemoveSource(g_runLoop, - IONotificationPortGetRunLoopSource(g_notifyPortRef), - kCFRunLoopCommonModes); - - IODeregisterForSystemPower(&g_notifierObject); - IOServiceClose(g_rootPort); - IONotificationPortDestroy(g_notifyPortRef); - 
CFRunLoopStop(g_runLoop); - - g_notifyPortRef = NULL; - g_notifierObject = 0; - g_rootPort = 0; - g_runLoop = NULL; -} - -*/ -import "C" - import ( - "context" "fmt" "runtime" "sync" "time" + "unsafe" + "github.com/ebitengine/purego" log "github.com/sirupsen/logrus" ) -var ( - serviceRegistry = make(map[*Detector]struct{}) - serviceRegistryMu sync.Mutex +// IOKit message types from IOKit/IOMessage.h. +const ( + kIOMessageCanSystemSleep uintptr = 0xe0000270 + kIOMessageSystemWillSleep uintptr = 0xe0000280 + kIOMessageSystemHasPoweredOn uintptr = 0xe0000300 ) -//export sleepCallbackBridge -func sleepCallbackBridge() { - log.Info("sleepCallbackBridge event triggered") +var ( + ioKit iokitFuncs + cf cfFuncs + cfCommonModes uintptr - serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() + libInitOnce sync.Once + libInitErr error - for svc := range serviceRegistry { - svc.triggerCallback(EventTypeSleep) - } + // callbackThunk is the single C-callable trampoline registered with IOKit. + callbackThunk uintptr + + serviceRegistry = make(map[*Detector]struct{}) + serviceRegistryMu sync.Mutex + session *runLoopSession + + // lifecycleMu serializes Register/Deregister so a new registration can't + // start a second runloop while a previous teardown is still pending. + lifecycleMu sync.Mutex +) + +// iokitFuncs holds IOKit symbols resolved once at init. +type iokitFuncs struct { + IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr + IODeregisterForSystemPower func(notifier *uintptr) int32 + IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32 + IOServiceClose func(connect uintptr) int32 + IONotificationPortGetRunLoopSource func(port uintptr) uintptr + IONotificationPortDestroy func(port uintptr) } -//export resumedCallbackBridge -func resumedCallbackBridge() { - log.Info("resumedCallbackBridge event triggered") +// cfFuncs holds CoreFoundation symbols resolved once at init. 
+type cfFuncs struct { + CFRunLoopGetCurrent func() uintptr + CFRunLoopRun func() + CFRunLoopStop func(rl uintptr) + CFRunLoopAddSource func(rl, source, mode uintptr) + CFRunLoopRemoveSource func(rl, source, mode uintptr) } -//export suspendedCallbackBridge -func suspendedCallbackBridge() { - log.Info("suspendedCallbackBridge event triggered") +// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil +// session means no runloop is active and the next Register must start one. +type runLoopSession struct { + rl uintptr + port uintptr + notifier uintptr + rp uintptr } -//export poweredOnCallbackBridge -func poweredOnCallbackBridge() { - log.Info("poweredOnCallbackBridge event triggered") - serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() - - for svc := range serviceRegistry { - svc.triggerCallback(EventTypeWakeUp) - } +// detectorSnapshot pins a detector's callback and done channel so dispatch +// runs with values valid at snapshot time, even if a concurrent +// Deregister/Register rewrites the detector's fields. +type detectorSnapshot struct { + detector *Detector + callback func(event EventType) + done <-chan struct{} } +// Detector delivers sleep and wake events to a registered callback. type Detector struct { callback func(event EventType) - ctx context.Context - cancel context.CancelFunc -} - -func NewDetector() (*Detector, error) { - return &Detector{}, nil + done chan struct{} } +// Register installs callback for power events. The first registration starts +// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit +// registration succeeds or fails; subsequent registrations just add to the +// dispatch set. 
func (d *Detector) Register(callback func(event EventType)) error { - serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() + lifecycleMu.Lock() + defer lifecycleMu.Unlock() + serviceRegistryMu.Lock() if _, exists := serviceRegistry[d]; exists { + serviceRegistryMu.Unlock() return fmt.Errorf("detector service already registered") } - d.callback = callback + d.done = make(chan struct{}) + serviceRegistry[d] = struct{}{} + needSetup := session == nil + serviceRegistryMu.Unlock() - d.ctx, d.cancel = context.WithCancel(context.Background()) - - if len(serviceRegistry) > 0 { - serviceRegistry[d] = struct{}{} + if !needSetup { return nil } - serviceRegistry[d] = struct{}{} - - // CFRunLoop must run on a single fixed OS thread - go func() { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - C.registerNotifications() - }() + errCh := make(chan error, 1) + go runRunLoop(errCh) + if err := <-errCh; err != nil { + serviceRegistryMu.Lock() + delete(serviceRegistry, d) + close(d.done) + d.done = nil + serviceRegistryMu.Unlock() + return err + } log.Info("sleep detection service started on macOS") return nil } -// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down -// and the runloop is stopped and cleaned up. +// Deregister removes the detector. When the last detector leaves, IOKit +// notifications are torn down and the runloop is stopped. 
func (d *Detector) Deregister() error { + lifecycleMu.Lock() + defer lifecycleMu.Unlock() + serviceRegistryMu.Lock() - defer serviceRegistryMu.Unlock() - _, exists := serviceRegistry[d] - if !exists { + if _, exists := serviceRegistry[d]; !exists { + serviceRegistryMu.Unlock() return nil } - - // cancel and remove this detector - d.cancel() + close(d.done) delete(serviceRegistry, d) - // If other Detectors still exist, leave IOKit running if len(serviceRegistry) > 0 { + serviceRegistryMu.Unlock() return nil } + sess := session + serviceRegistryMu.Unlock() log.Info("sleep detection service stopping (deregister)") - // Deregister IOKit notifications, stop runloop, and free resources - C.unregisterNotifications() + if sess == nil { + return nil + } + + if sess.rl != 0 && sess.port != 0 { + source := ioKit.IONotificationPortGetRunLoopSource(sess.port) + cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes) + } + if sess.notifier != 0 { + n := sess.notifier + ioKit.IODeregisterForSystemPower(&n) + } + + // Clear session only after IODeregisterForSystemPower returns so any + // in-flight powerCallback can still look up session.rp to ack sleep. 
+ serviceRegistryMu.Lock() + session = nil + serviceRegistryMu.Unlock() + + if sess.rp != 0 { + ioKit.IOServiceClose(sess.rp) + } + if sess.port != 0 { + ioKit.IONotificationPortDestroy(sess.port) + } + if sess.rl != 0 { + cf.CFRunLoopStop(sess.rl) + } return nil } -func (d *Detector) triggerCallback(event EventType) { - doneChan := make(chan struct{}) +func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) { + if cb == nil || done == nil { + return + } + select { + case <-done: + return + default: + } + + doneChan := make(chan struct{}) timeout := time.NewTimer(500 * time.Millisecond) defer timeout.Stop() - cb := d.callback - go func(callback func(event EventType)) { + go func() { + defer close(doneChan) + defer func() { + if r := recover(); r != nil { + log.Errorf("panic in sleep callback: %v", r) + } + }() log.Info("sleep detection event fired") - callback(event) - close(doneChan) - }(cb) + cb(event) + }() select { case <-doneChan: - case <-d.ctx.Done(): + case <-done: case <-timeout.C: - log.Warnf("sleep callback timed out") + log.Warn("sleep callback timed out") } } + +// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector. 
+func NewDetector() (*Detector, error) { + if err := initLibs(); err != nil { + return nil, err + } + return &Detector{}, nil +} + +func initLibs() error { + libInitOnce.Do(func() { + iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + libInitErr = fmt.Errorf("dlopen IOKit: %w", err) + return + } + cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err) + return + } + + purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower") + purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower") + purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange") + purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose") + purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource") + purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy") + + purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent") + purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun") + purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop") + purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource") + purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource") + + modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes") + if err != nil { + libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err) + return + } + // Launder the uintptr-to-pointer conversion through a Go variable so + // go vet's unsafeptr analyzer doesn't flag a system-library global. 
+ cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr)) + + // NewCallback slots are a finite, non-reclaimable resource, so register + // a single thunk that dispatches to the current Detector set. + callbackThunk = purego.NewCallback(powerCallback) + }) + return libInitErr +} + +// powerCallback is the IOServiceInterestCallback trampoline, invoked on the +// runloop thread. A Go panic crossing the purego boundary has undefined +// behavior, so contain it here. +func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr { + defer func() { + if r := recover(); r != nil { + log.Errorf("panic in sleep powerCallback: %v", r) + } + }() + switch messageType { + case kIOMessageCanSystemSleep: + // Not acknowledging forces a 30s IOKit timeout before idle sleep. + allowPowerChange(messageArgument) + case kIOMessageSystemWillSleep: + dispatchEvent(EventTypeSleep) + allowPowerChange(messageArgument) + case kIOMessageSystemHasPoweredOn: + dispatchEvent(EventTypeWakeUp) + } + return 0 +} + +func allowPowerChange(messageArgument uintptr) { + serviceRegistryMu.Lock() + var port uintptr + if session != nil { + port = session.rp + } + serviceRegistryMu.Unlock() + if port != 0 { + ioKit.IOAllowPowerChange(port, messageArgument) + } +} + +func dispatchEvent(event EventType) { + serviceRegistryMu.Lock() + snaps := make([]detectorSnapshot, 0, len(serviceRegistry)) + for d := range serviceRegistry { + snaps = append(snaps, detectorSnapshot{ + detector: d, + callback: d.callback, + done: d.done, + }) + } + serviceRegistryMu.Unlock() + + for _, s := range snaps { + s.detector.triggerCallback(event, s.callback, s.done) + } +} + +// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup +// result is reported on errCh so Register can surface failures synchronously. 
+func runRunLoop(errCh chan<- error) {
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+
+	sess, err := setupSession()
+	if err == nil {
+		serviceRegistryMu.Lock()
+		session = sess
+		serviceRegistryMu.Unlock()
+	}
+	errCh <- err
+	if err != nil {
+		return
+	}
+
+	defer func() {
+		if r := recover(); r != nil {
+			log.Errorf("panic in sleep runloop: %v", r)
+		}
+	}()
+	cf.CFRunLoopRun()
+}
+
+// setupSession performs the IOKit registration on the current thread. Panics
+// are converted to errors so runRunLoop never leaves errCh unsent.
+func setupSession() (s *runLoopSession, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panic during runloop setup: %v", r)
+		}
+	}()
+
+	var portRef, notifier uintptr
+	rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, &notifier)
+	if rp == 0 {
+		return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
+	}
+
+	rl := cf.CFRunLoopGetCurrent()
+	source := ioKit.IONotificationPortGetRunLoopSource(portRef)
+	cf.CFRunLoopAddSource(rl, source, cfCommonModes)
+
+	return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
+}
diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go
index a870c1145..2a2fa2366 100644
--- a/client/internal/wg_iface_monitor.go
+++ b/client/internal/wg_iface_monitor.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"net"
 	"runtime"
-	"time"
 
 	log "github.com/sirupsen/logrus"
 
@@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
 
 // Start begins monitoring the WireGuard interface.
 // It relies on the provided context cancellation to stop.
+//
+// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
+// to avoid the allocation churn of repeatedly dumping the kernel link
+// table; on other platforms it falls back to a low-frequency poll.
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { defer close(m.done) @@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - log.Infof("Interface monitor: stopped for %s", ifaceName) - return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) - case <-ticker.C: - currentIndex, err := getInterfaceIndex(ifaceName) - if err != nil { - // Interface was deleted - log.Infof("Interface monitor: %s deleted", ifaceName) - return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) - } - - // Check if interface index changed (interface was recreated) - if currentIndex != expectedIndex { - log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", - ifaceName, expectedIndex, currentIndex) - return true, nil - } - } - } - + return watchInterface(ctx, ifaceName, expectedIndex) } // getInterfaceIndex returns the index of a network interface by name. diff --git a/client/internal/wg_iface_monitor_linux.go b/client/internal/wg_iface_monitor_linux.go new file mode 100644 index 000000000..2662b99d6 --- /dev/null +++ b/client/internal/wg_iface_monitor_linux.go @@ -0,0 +1,134 @@ +//go:build linux + +package internal + +import ( + "context" + "fmt" + "syscall" + + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" +) + +// watchInterface uses an RTNLGRP_LINK netlink subscription to detect +// deletion or recreation of the WireGuard interface. +// +// The previous implementation polled net.InterfaceByName every 2 s, which +// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the +// entire kernel link table on every call. 
On hosts with many veth +// interfaces (containers, bridges) the resulting allocation churn was on +// the order of ~1 GB/day from this single ticker, which on small ARM +// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678). +// +// The event-driven version below allocates only when the kernel actually +// publishes a link event for the tracked interface — typically zero +// allocations between events. +func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) { + done := make(chan struct{}) + defer close(done) + + // Buffer the channel to absorb event bursts (e.g. when many veth + // pairs are created/destroyed at once by container runtimes). + linkChan := make(chan netlink.LinkUpdate, 32) + if err := netlink.LinkSubscribe(linkChan, done); err != nil { + // Return shouldRestart=true so the engine recovers monitoring + // via triggerClientRestart instead of silently losing it for + // the rest of the process lifetime. + return true, fmt.Errorf("subscribe to link updates: %w", err) + } + + // Race window: the interface could have been deleted (or recreated) + // between the initial getInterfaceIndex() in Start and LinkSubscribe + // completing its handshake with the kernel. Re-check explicitly so we + // do not block forever waiting for an event that already fired. 
+ if currentIndex, err := getInterfaceIndex(ifaceName); err != nil { + log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } else if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err()) + + case update, ok := <-linkChan: + if !ok { + // The vishvananda/netlink subscription goroutine closes + // the channel on receive errors. Signal the engine to + // restart so monitoring is re-established instead of + // silently ending. + log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName) + return true, fmt.Errorf("link subscription channel closed unexpectedly") + } + if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart { + return true, err + } + } + } +} + +// inspectLinkEvent classifies a single netlink link update against the +// tracked WireGuard interface. It returns (true, err) when the engine +// should restart monitoring; (false, nil) means the event is unrelated +// and the caller should keep waiting. +// +// The error component, when non-nil, describes the kernel-side reason +// (deletion or rename); the recreation case returns (true, nil) since +// no error condition is reported. 
+func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) { + eventIndex := int(update.Index) + eventName := "" + if attrs := update.Attrs(); attrs != nil { + eventName = attrs.Name + } + + switch update.Header.Type { + case syscall.RTM_DELLINK: + return inspectDelLink(eventIndex, ifaceName, expectedIndex) + case syscall.RTM_NEWLINK: + return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex) + } + return false, nil +} + +// inspectDelLink reports a restart when an RTM_DELLINK arrives for the +// tracked interface index. +func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) { + if eventIndex != expectedIndex { + return false, nil + } + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted", ifaceName) +} + +// inspectNewLink reports a restart when an RTM_NEWLINK either: +// +// 1. Introduces a link with our name at a different index (recreation +// after a delete), or +// +// 2. Reports a link still at our index but with a different name +// (in-place rename). The previous polling implementation caught +// this implicitly because net.InterfaceByName(ifaceName) would +// start failing; the event-driven version has to test it. +// +// Same name + same index is just a flag/state change on the existing +// interface and is ignored. 
+func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) { + if eventName == ifaceName && eventIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, eventIndex) + return true, nil + } + if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName { + log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine", + ifaceName, eventName, expectedIndex) + return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName) + } + return false, nil +} diff --git a/client/internal/wg_iface_monitor_other.go b/client/internal/wg_iface_monitor_other.go new file mode 100644 index 000000000..afebbf4df --- /dev/null +++ b/client/internal/wg_iface_monitor_other.go @@ -0,0 +1,56 @@ +//go:build !linux + +package internal + +import ( + "context" + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +// watchInterface polls net.InterfaceByName at a fixed interval to detect +// deletion or recreation of the WireGuard interface. +// +// This is the fallback used on non-Linux desktop and server platforms +// (darwin, windows, freebsd). It is also compiled on android and ios so +// the package builds on every supported GOOS, but it is never reached +// at runtime there because Start() in wg_iface_monitor.go exits early +// on mobile platforms. +// +// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven +// RTNLGRP_LINK netlink subscription instead, because on Linux +// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which +// dumps the entire kernel link table on every call and produces +// significant allocation churn (netbirdio/netbird#3678). +// +// Windows is also reported in #3678 as affected by RSS climb. A future +// follow-up could implement an event-driven watcher there using +// NotifyIpInterfaceChange from iphlpapi. 
+func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 043673904..33f5ab1b0 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -50,10 +50,11 @@ type CustomLogger interface { } type selectRoute struct { - NetID string - Network netip.Prefix - Domains domain.List - Selected bool + NetID string + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } func init() { @@ -198,6 +199,7 @@ func (c *Client) GetStatusDetails() *StatusDetails { } pi := PeerInfo{ IP: p.IP, + IPv6: p.IPv6, FQDN: p.FQDN, LocalIceCandidateEndpoint: p.LocalIceCandidateEndpoint, RemoteIceCandidateEndpoint: p.RemoteIceCandidateEndpoint, @@ -216,7 +218,7 @@ func (c *Client) GetStatusDetails() *StatusDetails { } peerInfos[n] = pi } - return &StatusDetails{items: peerInfos, fqdn: fullStatus.LocalPeerState.FQDN, ip: fullStatus.LocalPeerState.IP} + return &StatusDetails{items: peerInfos, fqdn: fullStatus.LocalPeerState.FQDN, ip: fullStatus.LocalPeerState.IP, ipv6: fullStatus.LocalPeerState.IPv6} } // SetConnectionListener set the network connection listener @@ 
-366,72 +368,102 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } routeManager := engine.GetRouteManager() - routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } + routesMap := routeManager.GetClientRoutesWithNetID() routeSelector := routeManager.GetRouteSelector() if routeSelector == nil { return nil, fmt.Errorf("could not get route selector") } + v6ExitMerged := route.V6ExitMergeSet(routesMap) + routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + + return prepareRouteSelectionDetails(routes, resolvedDomains), nil +} + +func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute { var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + if _, ok := v6Merged[id]; ok { + continue + } + + r := &selectRoute{ NetID: string(id), Network: rt[0].Network, Domains: rt[0].Domains, - Selected: routeSelector.IsSelected(id), + Selected: isSelected(id), } - routes = append(routes, route) + + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6Merged[v6ID]; ok { + r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { - iPrefix := routes[i].Network.Bits() - jPrefix := routes[j].Network.Bits() - - if iPrefix == jPrefix { - iAddr := routes[i].Network.Addr() - jAddr := routes[j].Network.Addr() - if iAddr == jAddr { - return routes[i].NetID < routes[j].NetID - } - return iAddr.String() < jAddr.String() + iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits() + if iBits != jBits { + return iBits < jBits } - return iPrefix < jPrefix + iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr() + if iAddr != jAddr { + 
return iAddr.Less(jAddr) + } + return routes[i].NetID < routes[j].NetID }) - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - return prepareRouteSelectionDetails(routes, resolvedDomains), nil - + return routes } func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { var routeSelection []RoutesSelectionInfo for _, r := range routes { - domainList := make([]DomainInfo, 0) + // resolvedDomains is keyed by the resolved domain (e.g. api.ipify.org), + // not the configured pattern (e.g. *.ipify.org). Group entries whose + // ParentDomain belongs to this route, mirroring the daemon logic in + // client/server/network.go. + domainList := make([]DomainInfo, 0, len(r.Domains)) + domainIndex := make(map[domain.Domain]int, len(r.Domains)) for _, d := range r.Domains { - domainResp := DomainInfo{ - Domain: d.SafeString(), - } - - if info, exists := resolvedDomains[d]; exists { - var ipStrings []string - for _, prefix := range info.Prefixes { - ipStrings = append(ipStrings, prefix.Addr().String()) - } - domainResp.ResolvedIPs = strings.Join(ipStrings, ", ") - } - domainList = append(domainList, domainResp) + domainIndex[d] = len(domainList) + domainList = append(domainList, DomainInfo{Domain: d.SafeString()}) } + + for _, info := range resolvedDomains { + idx, ok := domainIndex[info.ParentDomain] + if !ok { + continue + } + for _, prefix := range info.Prefixes { + domainList[idx].AddResolvedIP(prefix.Addr().String()) + } + } + domainDetails := DomainDetails{items: domainList} + + // For dynamic (DNS) routes, expose the joined domain pattern as the + // Network value so it matches the peer.routes entries on the Swift + // side (mirroring the Android bridge in client/android/client.go). 
+ netStr := r.Network.String() + if len(r.Domains) > 0 { + netStr = r.Domains.SafeString() + } + for _, extra := range r.extraNetworks { + netStr += ", " + extra.String() + } + routeSelection = append(routeSelection, RoutesSelectionInfo{ ID: r.NetID, - Network: r.Network.String(), + Network: netStr, Domains: &domainDetails, Selected: r.Selected, }) @@ -459,7 +491,9 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } @@ -486,7 +520,9 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go index 9b00568be..025cd94cd 100644 --- a/client/ios/NetBirdSDK/peer_notifier.go +++ b/client/ios/NetBirdSDK/peer_notifier.go @@ -5,6 +5,7 @@ package NetBirdSDK // PeerInfo describe information about the peers. 
It designed for the UI usage type PeerInfo struct { IP string + IPv6 string FQDN string LocalIceCandidateEndpoint string RemoteIceCandidateEndpoint string @@ -23,6 +24,11 @@ type PeerInfo struct { Routes RoutesDetails } +// GetIPv6 returns the IPv6 address of the peer +func (p PeerInfo) GetIPv6() string { + return p.IPv6 +} + // GetRoutes return with RouteDetails func (p PeerInfo) GetRouteDetails() *RoutesDetails { return &p.Routes @@ -57,6 +63,7 @@ type StatusDetails struct { items []PeerInfo fqdn string ip string + ipv6 string } // Add new PeerInfo to the collection @@ -100,3 +107,8 @@ func (array StatusDetails) GetFQDN() string { func (array StatusDetails) GetIP() string { return array.ip } + +// GetIPv6 return with the IPv6 of the local peer +func (array StatusDetails) GetIPv6() string { + return array.ipv6 +} diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index c26a6decd..ed49ccddb 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -110,6 +110,24 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { return cfg.RosenpassPermissive, err } +// GetDisableIPv6 reads disable IPv6 setting from config file +func (p *Preferences) GetDisableIPv6() (bool, error) { + if p.configInput.DisableIPv6 != nil { + return *p.configInput.DisableIPv6, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableIPv6, err +} + +// SetDisableIPv6 stores the given value and waits for commit +func (p *Preferences) SetDisableIPv6(disable bool) { + p.configInput.DisableIPv6 = &disable +} + // Commit write out the changes into config file func (p *Preferences) Commit() error { // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename) diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go index 7b84d6e1c..025313bfa 100644 --- a/client/ios/NetBirdSDK/routes.go +++ 
b/client/ios/NetBirdSDK/routes.go @@ -34,7 +34,34 @@ type DomainDetails struct { type DomainInfo struct { Domain string - ResolvedIPs string + resolvedIPs ResolvedIPs +} + +func (d *DomainInfo) AddResolvedIP(ipAddress string) { + d.resolvedIPs.Add(ipAddress) +} + +func (d *DomainInfo) GetResolvedIPs() *ResolvedIPs { + return &d.resolvedIPs +} + +type ResolvedIPs struct { + items []string +} + +func (r *ResolvedIPs) Add(ipAddress string) { + r.items = append(r.items, ipAddress) +} + +func (r *ResolvedIPs) Get(i int) string { + if i < 0 || i >= len(r.items) { + return "" + } + return r.items[i] +} + +func (r *ResolvedIPs) Size() int { + return len(r.items) } // Add new PeerInfo to the collection diff --git a/client/netbird.wxs b/client/netbird.wxs index 03221dd91..6f18b63b5 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -13,15 +13,25 @@ + + + - - + + + + + + + + + @@ -46,8 +56,30 @@ + + + + + + + + + + + + + + + + + + + diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 6506307d3..2c054c99a 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -143,56 +143,6 @@ func (ExposeProtocol) EnumDescriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{1} } -// avoid collision with loglevel enum -type OSLifecycleRequest_CycleType int32 - -const ( - OSLifecycleRequest_UNKNOWN OSLifecycleRequest_CycleType = 0 - OSLifecycleRequest_SLEEP OSLifecycleRequest_CycleType = 1 - OSLifecycleRequest_WAKEUP OSLifecycleRequest_CycleType = 2 -) - -// Enum value maps for OSLifecycleRequest_CycleType. 
-var ( - OSLifecycleRequest_CycleType_name = map[int32]string{ - 0: "UNKNOWN", - 1: "SLEEP", - 2: "WAKEUP", - } - OSLifecycleRequest_CycleType_value = map[string]int32{ - "UNKNOWN": 0, - "SLEEP": 1, - "WAKEUP": 2, - } -) - -func (x OSLifecycleRequest_CycleType) Enum() *OSLifecycleRequest_CycleType { - p := new(OSLifecycleRequest_CycleType) - *p = x - return p -} - -func (x OSLifecycleRequest_CycleType) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[2].Descriptor() -} - -func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[2] -} - -func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use OSLifecycleRequest_CycleType.Descriptor instead. -func (OSLifecycleRequest_CycleType) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{1, 0} -} - type SystemEvent_Severity int32 const ( @@ -229,11 +179,11 @@ func (x SystemEvent_Severity) String() string { } func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[3].Descriptor() + return file_daemon_proto_enumTypes[2].Descriptor() } func (SystemEvent_Severity) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[3] + return &file_daemon_proto_enumTypes[2] } func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { @@ -242,7 +192,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Severity.Descriptor instead. 
func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53, 0} + return file_daemon_proto_rawDescGZIP(), []int{51, 0} } type SystemEvent_Category int32 @@ -284,11 +234,11 @@ func (x SystemEvent_Category) String() string { } func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[4].Descriptor() + return file_daemon_proto_enumTypes[3].Descriptor() } func (SystemEvent_Category) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[4] + return &file_daemon_proto_enumTypes[3] } func (x SystemEvent_Category) Number() protoreflect.EnumNumber { @@ -297,7 +247,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Category.Descriptor instead. func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53, 1} + return file_daemon_proto_rawDescGZIP(), []int{51, 1} } type EmptyRequest struct { @@ -336,86 +286,6 @@ func (*EmptyRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } -type OSLifecycleRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Type OSLifecycleRequest_CycleType `protobuf:"varint,1,opt,name=type,proto3,enum=daemon.OSLifecycleRequest_CycleType" json:"type,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *OSLifecycleRequest) Reset() { - *x = OSLifecycleRequest{} - mi := &file_daemon_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *OSLifecycleRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*OSLifecycleRequest) ProtoMessage() {} - -func (x *OSLifecycleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[1] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - 
ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use OSLifecycleRequest.ProtoReflect.Descriptor instead. -func (*OSLifecycleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{1} -} - -func (x *OSLifecycleRequest) GetType() OSLifecycleRequest_CycleType { - if x != nil { - return x.Type - } - return OSLifecycleRequest_UNKNOWN -} - -type OSLifecycleResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *OSLifecycleResponse) Reset() { - *x = OSLifecycleResponse{} - mi := &file_daemon_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *OSLifecycleResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*OSLifecycleResponse) ProtoMessage() {} - -func (x *OSLifecycleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use OSLifecycleResponse.ProtoReflect.Descriptor instead. -func (*OSLifecycleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{2} -} - type LoginRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // setupKey netbird setup key. 
@@ -472,13 +342,14 @@ type LoginRequest struct { EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + DisableIpv6 *bool `protobuf:"varint,40,opt,name=disable_ipv6,json=disableIpv6,proto3,oneof" json:"disable_ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { *x = LoginRequest{} - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -490,7 +361,7 @@ func (x *LoginRequest) String() string { func (*LoginRequest) ProtoMessage() {} func (x *LoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -503,7 +374,7 @@ func (x *LoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead. 
func (*LoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{3} + return file_daemon_proto_rawDescGZIP(), []int{1} } func (x *LoginRequest) GetSetupKey() string { @@ -780,6 +651,13 @@ func (x *LoginRequest) GetSshJWTCacheTTL() int32 { return 0 } +func (x *LoginRequest) GetDisableIpv6() bool { + if x != nil && x.DisableIpv6 != nil { + return *x.DisableIpv6 + } + return false +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -792,7 +670,7 @@ type LoginResponse struct { func (x *LoginResponse) Reset() { *x = LoginResponse{} - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -804,7 +682,7 @@ func (x *LoginResponse) String() string { func (*LoginResponse) ProtoMessage() {} func (x *LoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -817,7 +695,7 @@ func (x *LoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead. 
func (*LoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{4} + return file_daemon_proto_rawDescGZIP(), []int{2} } func (x *LoginResponse) GetNeedsSSOLogin() bool { @@ -858,7 +736,7 @@ type WaitSSOLoginRequest struct { func (x *WaitSSOLoginRequest) Reset() { *x = WaitSSOLoginRequest{} - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -870,7 +748,7 @@ func (x *WaitSSOLoginRequest) String() string { func (*WaitSSOLoginRequest) ProtoMessage() {} func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -883,7 +761,7 @@ func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginRequest.ProtoReflect.Descriptor instead. 
func (*WaitSSOLoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{5} + return file_daemon_proto_rawDescGZIP(), []int{3} } func (x *WaitSSOLoginRequest) GetUserCode() string { @@ -909,7 +787,7 @@ type WaitSSOLoginResponse struct { func (x *WaitSSOLoginResponse) Reset() { *x = WaitSSOLoginResponse{} - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -921,7 +799,7 @@ func (x *WaitSSOLoginResponse) String() string { func (*WaitSSOLoginResponse) ProtoMessage() {} func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -934,7 +812,7 @@ func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginResponse.ProtoReflect.Descriptor instead. func (*WaitSSOLoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{6} + return file_daemon_proto_rawDescGZIP(), []int{4} } func (x *WaitSSOLoginResponse) GetEmail() string { @@ -954,7 +832,7 @@ type UpRequest struct { func (x *UpRequest) Reset() { *x = UpRequest{} - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -966,7 +844,7 @@ func (x *UpRequest) String() string { func (*UpRequest) ProtoMessage() {} func (x *UpRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -979,7 +857,7 @@ func (x *UpRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use UpRequest.ProtoReflect.Descriptor instead. 
func (*UpRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{7} + return file_daemon_proto_rawDescGZIP(), []int{5} } func (x *UpRequest) GetProfileName() string { @@ -1004,7 +882,7 @@ type UpResponse struct { func (x *UpResponse) Reset() { *x = UpResponse{} - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1016,7 +894,7 @@ func (x *UpResponse) String() string { func (*UpResponse) ProtoMessage() {} func (x *UpResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1029,7 +907,7 @@ func (x *UpResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use UpResponse.ProtoReflect.Descriptor instead. func (*UpResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{8} + return file_daemon_proto_rawDescGZIP(), []int{6} } type StatusRequest struct { @@ -1044,7 +922,7 @@ type StatusRequest struct { func (x *StatusRequest) Reset() { *x = StatusRequest{} - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1056,7 +934,7 @@ func (x *StatusRequest) String() string { func (*StatusRequest) ProtoMessage() {} func (x *StatusRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1069,7 +947,7 @@ func (x *StatusRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead. 
func (*StatusRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{9} + return file_daemon_proto_rawDescGZIP(), []int{7} } func (x *StatusRequest) GetGetFullPeerStatus() bool { @@ -1106,7 +984,7 @@ type StatusResponse struct { func (x *StatusResponse) Reset() { *x = StatusResponse{} - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1118,7 +996,7 @@ func (x *StatusResponse) String() string { func (*StatusResponse) ProtoMessage() {} func (x *StatusResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1131,7 +1009,7 @@ func (x *StatusResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead. func (*StatusResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{10} + return file_daemon_proto_rawDescGZIP(), []int{8} } func (x *StatusResponse) GetStatus() string { @@ -1163,7 +1041,7 @@ type DownRequest struct { func (x *DownRequest) Reset() { *x = DownRequest{} - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1175,7 +1053,7 @@ func (x *DownRequest) String() string { func (*DownRequest) ProtoMessage() {} func (x *DownRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1188,7 +1066,7 @@ func (x *DownRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DownRequest.ProtoReflect.Descriptor instead. 
func (*DownRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{11} + return file_daemon_proto_rawDescGZIP(), []int{9} } type DownResponse struct { @@ -1199,7 +1077,7 @@ type DownResponse struct { func (x *DownResponse) Reset() { *x = DownResponse{} - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1211,7 +1089,7 @@ func (x *DownResponse) String() string { func (*DownResponse) ProtoMessage() {} func (x *DownResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1224,7 +1102,7 @@ func (x *DownResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DownResponse.ProtoReflect.Descriptor instead. func (*DownResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{12} + return file_daemon_proto_rawDescGZIP(), []int{10} } type GetConfigRequest struct { @@ -1237,7 +1115,7 @@ type GetConfigRequest struct { func (x *GetConfigRequest) Reset() { *x = GetConfigRequest{} - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1249,7 +1127,7 @@ func (x *GetConfigRequest) String() string { func (*GetConfigRequest) ProtoMessage() {} func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1262,7 +1140,7 @@ func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigRequest.ProtoReflect.Descriptor instead. 
func (*GetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{13} + return file_daemon_proto_rawDescGZIP(), []int{11} } func (x *GetConfigRequest) GetProfileName() string { @@ -1312,13 +1190,14 @@ type GetConfigResponse struct { EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"` + DisableIpv6 bool `protobuf:"varint,27,opt,name=disable_ipv6,json=disableIpv6,proto3" json:"disable_ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *GetConfigResponse) Reset() { *x = GetConfigResponse{} - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1330,7 +1209,7 @@ func (x *GetConfigResponse) String() string { func (*GetConfigResponse) ProtoMessage() {} func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1343,7 +1222,7 @@ func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigResponse.ProtoReflect.Descriptor instead. 
func (*GetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{14} + return file_daemon_proto_rawDescGZIP(), []int{12} } func (x *GetConfigResponse) GetManagementUrl() string { @@ -1528,6 +1407,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 { return 0 } +func (x *GetConfigResponse) GetDisableIpv6() bool { + if x != nil { + return x.DisableIpv6 + } + return false +} + // PeerState contains the latest state of a peer type PeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1549,13 +1435,14 @@ type PeerState struct { Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` SshHostKey []byte `protobuf:"bytes,19,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` + Ipv6 string `protobuf:"bytes,20,opt,name=ipv6,proto3" json:"ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *PeerState) Reset() { *x = PeerState{} - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1567,7 +1454,7 @@ func (x *PeerState) String() string { func (*PeerState) ProtoMessage() {} func (x *PeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1580,7 +1467,7 @@ func (x *PeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerState.ProtoReflect.Descriptor instead. 
func (*PeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{15} + return file_daemon_proto_rawDescGZIP(), []int{13} } func (x *PeerState) GetIP() string { @@ -1709,6 +1596,13 @@ func (x *PeerState) GetSshHostKey() []byte { return nil } +func (x *PeerState) GetIpv6() string { + if x != nil { + return x.Ipv6 + } + return "" +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1719,13 +1613,14 @@ type LocalPeerState struct { RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"` + Ipv6 string `protobuf:"bytes,8,opt,name=ipv6,proto3" json:"ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *LocalPeerState) Reset() { *x = LocalPeerState{} - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1737,7 +1632,7 @@ func (x *LocalPeerState) String() string { func (*LocalPeerState) ProtoMessage() {} func (x *LocalPeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1750,7 +1645,7 @@ func (x *LocalPeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use LocalPeerState.ProtoReflect.Descriptor instead. 
func (*LocalPeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{16} + return file_daemon_proto_rawDescGZIP(), []int{14} } func (x *LocalPeerState) GetIP() string { @@ -1802,6 +1697,13 @@ func (x *LocalPeerState) GetNetworks() []string { return nil } +func (x *LocalPeerState) GetIpv6() string { + if x != nil { + return x.Ipv6 + } + return "" +} + // SignalState contains the latest state of a signal connection type SignalState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1814,7 +1716,7 @@ type SignalState struct { func (x *SignalState) Reset() { *x = SignalState{} - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1826,7 +1728,7 @@ func (x *SignalState) String() string { func (*SignalState) ProtoMessage() {} func (x *SignalState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1839,7 +1741,7 @@ func (x *SignalState) ProtoReflect() protoreflect.Message { // Deprecated: Use SignalState.ProtoReflect.Descriptor instead. 
func (*SignalState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{17} + return file_daemon_proto_rawDescGZIP(), []int{15} } func (x *SignalState) GetURL() string { @@ -1875,7 +1777,7 @@ type ManagementState struct { func (x *ManagementState) Reset() { *x = ManagementState{} - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1887,7 +1789,7 @@ func (x *ManagementState) String() string { func (*ManagementState) ProtoMessage() {} func (x *ManagementState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1900,7 +1802,7 @@ func (x *ManagementState) ProtoReflect() protoreflect.Message { // Deprecated: Use ManagementState.ProtoReflect.Descriptor instead. func (*ManagementState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{18} + return file_daemon_proto_rawDescGZIP(), []int{16} } func (x *ManagementState) GetURL() string { @@ -1936,7 +1838,7 @@ type RelayState struct { func (x *RelayState) Reset() { *x = RelayState{} - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1948,7 +1850,7 @@ func (x *RelayState) String() string { func (*RelayState) ProtoMessage() {} func (x *RelayState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1961,7 +1863,7 @@ func (x *RelayState) ProtoReflect() protoreflect.Message { // Deprecated: Use RelayState.ProtoReflect.Descriptor instead. 
func (*RelayState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{19} + return file_daemon_proto_rawDescGZIP(), []int{17} } func (x *RelayState) GetURI() string { @@ -1997,7 +1899,7 @@ type NSGroupState struct { func (x *NSGroupState) Reset() { *x = NSGroupState{} - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2009,7 +1911,7 @@ func (x *NSGroupState) String() string { func (*NSGroupState) ProtoMessage() {} func (x *NSGroupState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2022,7 +1924,7 @@ func (x *NSGroupState) ProtoReflect() protoreflect.Message { // Deprecated: Use NSGroupState.ProtoReflect.Descriptor instead. func (*NSGroupState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{20} + return file_daemon_proto_rawDescGZIP(), []int{18} } func (x *NSGroupState) GetServers() []string { @@ -2067,7 +1969,7 @@ type SSHSessionInfo struct { func (x *SSHSessionInfo) Reset() { *x = SSHSessionInfo{} - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2079,7 +1981,7 @@ func (x *SSHSessionInfo) String() string { func (*SSHSessionInfo) ProtoMessage() {} func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2092,7 +1994,7 @@ func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead. 
func (*SSHSessionInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{21} + return file_daemon_proto_rawDescGZIP(), []int{19} } func (x *SSHSessionInfo) GetUsername() string { @@ -2141,7 +2043,7 @@ type SSHServerState struct { func (x *SSHServerState) Reset() { *x = SSHServerState{} - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2153,7 +2055,7 @@ func (x *SSHServerState) String() string { func (*SSHServerState) ProtoMessage() {} func (x *SSHServerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2166,7 +2068,7 @@ func (x *SSHServerState) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead. func (*SSHServerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{22} + return file_daemon_proto_rawDescGZIP(), []int{20} } func (x *SSHServerState) GetEnabled() bool { @@ -2202,7 +2104,7 @@ type FullStatus struct { func (x *FullStatus) Reset() { *x = FullStatus{} - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2214,7 +2116,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2227,7 +2129,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. 
func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{21} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -2309,7 +2211,7 @@ type ListNetworksRequest struct { func (x *ListNetworksRequest) Reset() { *x = ListNetworksRequest{} - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2321,7 +2223,7 @@ func (x *ListNetworksRequest) String() string { func (*ListNetworksRequest) ProtoMessage() {} func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2334,7 +2236,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. 
func (*ListNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{22} } type ListNetworksResponse struct { @@ -2346,7 +2248,7 @@ type ListNetworksResponse struct { func (x *ListNetworksResponse) Reset() { *x = ListNetworksResponse{} - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2358,7 +2260,7 @@ func (x *ListNetworksResponse) String() string { func (*ListNetworksResponse) ProtoMessage() {} func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2371,7 +2273,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. 
func (*ListNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{23} } func (x *ListNetworksResponse) GetRoutes() []*Network { @@ -2392,7 +2294,7 @@ type SelectNetworksRequest struct { func (x *SelectNetworksRequest) Reset() { *x = SelectNetworksRequest{} - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2404,7 +2306,7 @@ func (x *SelectNetworksRequest) String() string { func (*SelectNetworksRequest) ProtoMessage() {} func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2417,7 +2319,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. 
func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{24} } func (x *SelectNetworksRequest) GetNetworkIDs() []string { @@ -2449,7 +2351,7 @@ type SelectNetworksResponse struct { func (x *SelectNetworksResponse) Reset() { *x = SelectNetworksResponse{} - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2461,7 +2363,7 @@ func (x *SelectNetworksResponse) String() string { func (*SelectNetworksResponse) ProtoMessage() {} func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2474,7 +2376,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{25} } type IPList struct { @@ -2486,7 +2388,7 @@ type IPList struct { func (x *IPList) Reset() { *x = IPList{} - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2498,7 +2400,7 @@ func (x *IPList) String() string { func (*IPList) ProtoMessage() {} func (x *IPList) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2511,7 +2413,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message { // Deprecated: Use IPList.ProtoReflect.Descriptor instead. 
func (*IPList) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *IPList) GetIps() []string { @@ -2534,7 +2436,7 @@ type Network struct { func (x *Network) Reset() { *x = Network{} - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2546,7 +2448,7 @@ func (x *Network) String() string { func (*Network) ProtoMessage() {} func (x *Network) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2559,7 +2461,7 @@ func (x *Network) ProtoReflect() protoreflect.Message { // Deprecated: Use Network.ProtoReflect.Descriptor instead. func (*Network) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{27} } func (x *Network) GetID() string { @@ -2611,7 +2513,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2623,7 +2525,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2636,7 +2538,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. 
func (*PortInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -2693,7 +2595,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2705,7 +2607,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2718,7 +2620,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *ForwardingRule) GetProtocol() string { @@ -2765,7 +2667,7 @@ type ForwardingRulesResponse struct { func (x *ForwardingRulesResponse) Reset() { *x = ForwardingRulesResponse{} - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2777,7 +2679,7 @@ func (x *ForwardingRulesResponse) String() string { func (*ForwardingRulesResponse) ProtoMessage() {} func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2790,7 +2692,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use 
ForwardingRulesResponse.ProtoReflect.Descriptor instead. func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{30} } func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule { @@ -2813,7 +2715,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2825,7 +2727,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2838,7 +2740,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. 
func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{33} + return file_daemon_proto_rawDescGZIP(), []int{31} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -2880,7 +2782,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2892,7 +2794,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2905,7 +2807,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. 
func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{34} + return file_daemon_proto_rawDescGZIP(), []int{32} } func (x *DebugBundleResponse) GetPath() string { @@ -2937,7 +2839,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2949,7 +2851,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2962,7 +2864,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. 
func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{35} + return file_daemon_proto_rawDescGZIP(), []int{33} } type GetLogLevelResponse struct { @@ -2974,7 +2876,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2986,7 +2888,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2999,7 +2901,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. 
func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{36} + return file_daemon_proto_rawDescGZIP(), []int{34} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -3018,7 +2920,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3030,7 +2932,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3043,7 +2945,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. 
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{37} + return file_daemon_proto_rawDescGZIP(), []int{35} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -3061,7 +2963,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3073,7 +2975,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3086,7 +2988,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{38} + return file_daemon_proto_rawDescGZIP(), []int{36} } // State represents a daemon state entry @@ -3099,7 +3001,7 @@ type State struct { func (x *State) Reset() { *x = State{} - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3111,7 +3013,7 @@ func (x *State) String() string { func (*State) ProtoMessage() {} func (x *State) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3124,7 +3026,7 @@ func (x *State) ProtoReflect() protoreflect.Message { // Deprecated: Use State.ProtoReflect.Descriptor instead. 
func (*State) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{39} + return file_daemon_proto_rawDescGZIP(), []int{37} } func (x *State) GetName() string { @@ -3143,7 +3045,7 @@ type ListStatesRequest struct { func (x *ListStatesRequest) Reset() { *x = ListStatesRequest{} - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3155,7 +3057,7 @@ func (x *ListStatesRequest) String() string { func (*ListStatesRequest) ProtoMessage() {} func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3168,7 +3070,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. func (*ListStatesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{40} + return file_daemon_proto_rawDescGZIP(), []int{38} } // ListStatesResponse contains a list of states @@ -3181,7 +3083,7 @@ type ListStatesResponse struct { func (x *ListStatesResponse) Reset() { *x = ListStatesResponse{} - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3193,7 +3095,7 @@ func (x *ListStatesResponse) String() string { func (*ListStatesResponse) ProtoMessage() {} func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3206,7 +3108,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. 
func (*ListStatesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{41} + return file_daemon_proto_rawDescGZIP(), []int{39} } func (x *ListStatesResponse) GetStates() []*State { @@ -3227,7 +3129,7 @@ type CleanStateRequest struct { func (x *CleanStateRequest) Reset() { *x = CleanStateRequest{} - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3239,7 +3141,7 @@ func (x *CleanStateRequest) String() string { func (*CleanStateRequest) ProtoMessage() {} func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3252,7 +3154,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. 
func (*CleanStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{42} + return file_daemon_proto_rawDescGZIP(), []int{40} } func (x *CleanStateRequest) GetStateName() string { @@ -3279,7 +3181,7 @@ type CleanStateResponse struct { func (x *CleanStateResponse) Reset() { *x = CleanStateResponse{} - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3291,7 +3193,7 @@ func (x *CleanStateResponse) String() string { func (*CleanStateResponse) ProtoMessage() {} func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3304,7 +3206,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. 
func (*CleanStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{43} + return file_daemon_proto_rawDescGZIP(), []int{41} } func (x *CleanStateResponse) GetCleanedStates() int32 { @@ -3325,7 +3227,7 @@ type DeleteStateRequest struct { func (x *DeleteStateRequest) Reset() { *x = DeleteStateRequest{} - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3337,7 +3239,7 @@ func (x *DeleteStateRequest) String() string { func (*DeleteStateRequest) ProtoMessage() {} func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3350,7 +3252,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. 
func (*DeleteStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{44} + return file_daemon_proto_rawDescGZIP(), []int{42} } func (x *DeleteStateRequest) GetStateName() string { @@ -3377,7 +3279,7 @@ type DeleteStateResponse struct { func (x *DeleteStateResponse) Reset() { *x = DeleteStateResponse{} - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3389,7 +3291,7 @@ func (x *DeleteStateResponse) String() string { func (*DeleteStateResponse) ProtoMessage() {} func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3402,7 +3304,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. 
func (*DeleteStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{45} + return file_daemon_proto_rawDescGZIP(), []int{43} } func (x *DeleteStateResponse) GetDeletedStates() int32 { @@ -3421,7 +3323,7 @@ type SetSyncResponsePersistenceRequest struct { func (x *SetSyncResponsePersistenceRequest) Reset() { *x = SetSyncResponsePersistenceRequest{} - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3433,7 +3335,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string { func (*SetSyncResponsePersistenceRequest) ProtoMessage() {} func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3446,7 +3348,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead. 
func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{46} + return file_daemon_proto_rawDescGZIP(), []int{44} } func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool { @@ -3464,7 +3366,7 @@ type SetSyncResponsePersistenceResponse struct { func (x *SetSyncResponsePersistenceResponse) Reset() { *x = SetSyncResponsePersistenceResponse{} - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3476,7 +3378,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string { func (*SetSyncResponsePersistenceResponse) ProtoMessage() {} func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[45] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3489,7 +3391,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead. 
func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{47} + return file_daemon_proto_rawDescGZIP(), []int{45} } type TCPFlags struct { @@ -3506,7 +3408,7 @@ type TCPFlags struct { func (x *TCPFlags) Reset() { *x = TCPFlags{} - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3518,7 +3420,7 @@ func (x *TCPFlags) String() string { func (*TCPFlags) ProtoMessage() {} func (x *TCPFlags) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[46] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3531,7 +3433,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message { // Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. func (*TCPFlags) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{48} + return file_daemon_proto_rawDescGZIP(), []int{46} } func (x *TCPFlags) GetSyn() bool { @@ -3593,7 +3495,7 @@ type TracePacketRequest struct { func (x *TracePacketRequest) Reset() { *x = TracePacketRequest{} - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3605,7 +3507,7 @@ func (x *TracePacketRequest) String() string { func (*TracePacketRequest) ProtoMessage() {} func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[47] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3618,7 +3520,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. 
func (*TracePacketRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49} + return file_daemon_proto_rawDescGZIP(), []int{47} } func (x *TracePacketRequest) GetSourceIp() string { @@ -3696,7 +3598,7 @@ type TraceStage struct { func (x *TraceStage) Reset() { *x = TraceStage{} - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3708,7 +3610,7 @@ func (x *TraceStage) String() string { func (*TraceStage) ProtoMessage() {} func (x *TraceStage) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[48] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3721,7 +3623,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message { // Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. func (*TraceStage) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{50} + return file_daemon_proto_rawDescGZIP(), []int{48} } func (x *TraceStage) GetName() string { @@ -3762,7 +3664,7 @@ type TracePacketResponse struct { func (x *TracePacketResponse) Reset() { *x = TracePacketResponse{} - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3774,7 +3676,7 @@ func (x *TracePacketResponse) String() string { func (*TracePacketResponse) ProtoMessage() {} func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[49] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3787,7 +3689,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. 
func (*TracePacketResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51} + return file_daemon_proto_rawDescGZIP(), []int{49} } func (x *TracePacketResponse) GetStages() []*TraceStage { @@ -3812,7 +3714,7 @@ type SubscribeRequest struct { func (x *SubscribeRequest) Reset() { *x = SubscribeRequest{} - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3824,7 +3726,7 @@ func (x *SubscribeRequest) String() string { func (*SubscribeRequest) ProtoMessage() {} func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[50] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3837,7 +3739,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. func (*SubscribeRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{52} + return file_daemon_proto_rawDescGZIP(), []int{50} } type SystemEvent struct { @@ -3855,7 +3757,7 @@ type SystemEvent struct { func (x *SystemEvent) Reset() { *x = SystemEvent{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3867,7 +3769,7 @@ func (x *SystemEvent) String() string { func (*SystemEvent) ProtoMessage() {} func (x *SystemEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[51] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3880,7 +3782,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead. 
func (*SystemEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53} + return file_daemon_proto_rawDescGZIP(), []int{51} } func (x *SystemEvent) GetId() string { @@ -3940,7 +3842,7 @@ type GetEventsRequest struct { func (x *GetEventsRequest) Reset() { *x = GetEventsRequest{} - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3952,7 +3854,7 @@ func (x *GetEventsRequest) String() string { func (*GetEventsRequest) ProtoMessage() {} func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[52] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3965,7 +3867,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead. func (*GetEventsRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{54} + return file_daemon_proto_rawDescGZIP(), []int{52} } type GetEventsResponse struct { @@ -3977,7 +3879,7 @@ type GetEventsResponse struct { func (x *GetEventsResponse) Reset() { *x = GetEventsResponse{} - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3989,7 +3891,7 @@ func (x *GetEventsResponse) String() string { func (*GetEventsResponse) ProtoMessage() {} func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[53] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4002,7 +3904,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead. 
func (*GetEventsResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{55} + return file_daemon_proto_rawDescGZIP(), []int{53} } func (x *GetEventsResponse) GetEvents() []*SystemEvent { @@ -4022,7 +3924,7 @@ type SwitchProfileRequest struct { func (x *SwitchProfileRequest) Reset() { *x = SwitchProfileRequest{} - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4034,7 +3936,7 @@ func (x *SwitchProfileRequest) String() string { func (*SwitchProfileRequest) ProtoMessage() {} func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[54] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4047,7 +3949,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. 
func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{56} + return file_daemon_proto_rawDescGZIP(), []int{54} } func (x *SwitchProfileRequest) GetProfileName() string { @@ -4072,7 +3974,7 @@ type SwitchProfileResponse struct { func (x *SwitchProfileResponse) Reset() { *x = SwitchProfileResponse{} - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4084,7 +3986,7 @@ func (x *SwitchProfileResponse) String() string { func (*SwitchProfileResponse) ProtoMessage() {} func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[55] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4097,7 +3999,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. 
func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{57} + return file_daemon_proto_rawDescGZIP(), []int{55} } type SetConfigRequest struct { @@ -4139,13 +4041,14 @@ type SetConfigRequest struct { EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + DisableIpv6 *bool `protobuf:"varint,35,opt,name=disable_ipv6,json=disableIpv6,proto3,oneof" json:"disable_ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *SetConfigRequest) Reset() { *x = SetConfigRequest{} - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4157,7 +4060,7 @@ func (x *SetConfigRequest) String() string { func (*SetConfigRequest) ProtoMessage() {} func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[56] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4170,7 +4073,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. 
func (*SetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{58} + return file_daemon_proto_rawDescGZIP(), []int{56} } func (x *SetConfigRequest) GetUsername() string { @@ -4411,6 +4314,13 @@ func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 { return 0 } +func (x *SetConfigRequest) GetDisableIpv6() bool { + if x != nil && x.DisableIpv6 != nil { + return *x.DisableIpv6 + } + return false +} + type SetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -4419,7 +4329,7 @@ type SetConfigResponse struct { func (x *SetConfigResponse) Reset() { *x = SetConfigResponse{} - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4431,7 +4341,7 @@ func (x *SetConfigResponse) String() string { func (*SetConfigResponse) ProtoMessage() {} func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[57] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4444,7 +4354,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. 
func (*SetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{59} + return file_daemon_proto_rawDescGZIP(), []int{57} } type AddProfileRequest struct { @@ -4457,7 +4367,7 @@ type AddProfileRequest struct { func (x *AddProfileRequest) Reset() { *x = AddProfileRequest{} - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4469,7 +4379,7 @@ func (x *AddProfileRequest) String() string { func (*AddProfileRequest) ProtoMessage() {} func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[58] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4482,7 +4392,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. func (*AddProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{60} + return file_daemon_proto_rawDescGZIP(), []int{58} } func (x *AddProfileRequest) GetUsername() string { @@ -4507,7 +4417,7 @@ type AddProfileResponse struct { func (x *AddProfileResponse) Reset() { *x = AddProfileResponse{} - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4519,7 +4429,7 @@ func (x *AddProfileResponse) String() string { func (*AddProfileResponse) ProtoMessage() {} func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[59] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4532,7 +4442,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor 
instead. func (*AddProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{61} + return file_daemon_proto_rawDescGZIP(), []int{59} } type RemoveProfileRequest struct { @@ -4545,7 +4455,7 @@ type RemoveProfileRequest struct { func (x *RemoveProfileRequest) Reset() { *x = RemoveProfileRequest{} - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4557,7 +4467,7 @@ func (x *RemoveProfileRequest) String() string { func (*RemoveProfileRequest) ProtoMessage() {} func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[60] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4570,7 +4480,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. 
func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{62} + return file_daemon_proto_rawDescGZIP(), []int{60} } func (x *RemoveProfileRequest) GetUsername() string { @@ -4595,7 +4505,7 @@ type RemoveProfileResponse struct { func (x *RemoveProfileResponse) Reset() { *x = RemoveProfileResponse{} - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4607,7 +4517,7 @@ func (x *RemoveProfileResponse) String() string { func (*RemoveProfileResponse) ProtoMessage() {} func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[61] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4620,7 +4530,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. 
func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{63} + return file_daemon_proto_rawDescGZIP(), []int{61} } type ListProfilesRequest struct { @@ -4632,7 +4542,7 @@ type ListProfilesRequest struct { func (x *ListProfilesRequest) Reset() { *x = ListProfilesRequest{} - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4644,7 +4554,7 @@ func (x *ListProfilesRequest) String() string { func (*ListProfilesRequest) ProtoMessage() {} func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[62] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4657,7 +4567,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. 
func (*ListProfilesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{64} + return file_daemon_proto_rawDescGZIP(), []int{62} } func (x *ListProfilesRequest) GetUsername() string { @@ -4676,7 +4586,7 @@ type ListProfilesResponse struct { func (x *ListProfilesResponse) Reset() { *x = ListProfilesResponse{} - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4688,7 +4598,7 @@ func (x *ListProfilesResponse) String() string { func (*ListProfilesResponse) ProtoMessage() {} func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[63] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4701,7 +4611,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. func (*ListProfilesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{65} + return file_daemon_proto_rawDescGZIP(), []int{63} } func (x *ListProfilesResponse) GetProfiles() []*Profile { @@ -4721,7 +4631,7 @@ type Profile struct { func (x *Profile) Reset() { *x = Profile{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4733,7 +4643,7 @@ func (x *Profile) String() string { func (*Profile) ProtoMessage() {} func (x *Profile) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[64] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4746,7 +4656,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message { // Deprecated: Use Profile.ProtoReflect.Descriptor instead. 
func (*Profile) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{66} + return file_daemon_proto_rawDescGZIP(), []int{64} } func (x *Profile) GetName() string { @@ -4771,7 +4681,7 @@ type GetActiveProfileRequest struct { func (x *GetActiveProfileRequest) Reset() { *x = GetActiveProfileRequest{} - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4783,7 +4693,7 @@ func (x *GetActiveProfileRequest) String() string { func (*GetActiveProfileRequest) ProtoMessage() {} func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[65] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4796,7 +4706,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. 
func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{67} + return file_daemon_proto_rawDescGZIP(), []int{65} } type GetActiveProfileResponse struct { @@ -4809,7 +4719,7 @@ type GetActiveProfileResponse struct { func (x *GetActiveProfileResponse) Reset() { *x = GetActiveProfileResponse{} - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4821,7 +4731,7 @@ func (x *GetActiveProfileResponse) String() string { func (*GetActiveProfileResponse) ProtoMessage() {} func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4834,7 +4744,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. 
func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{68} + return file_daemon_proto_rawDescGZIP(), []int{66} } func (x *GetActiveProfileResponse) GetProfileName() string { @@ -4861,7 +4771,7 @@ type LogoutRequest struct { func (x *LogoutRequest) Reset() { *x = LogoutRequest{} - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4873,7 +4783,7 @@ func (x *LogoutRequest) String() string { func (*LogoutRequest) ProtoMessage() {} func (x *LogoutRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[67] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4886,7 +4796,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. func (*LogoutRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{69} + return file_daemon_proto_rawDescGZIP(), []int{67} } func (x *LogoutRequest) GetProfileName() string { @@ -4911,7 +4821,7 @@ type LogoutResponse struct { func (x *LogoutResponse) Reset() { *x = LogoutResponse{} - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4923,7 +4833,7 @@ func (x *LogoutResponse) String() string { func (*LogoutResponse) ProtoMessage() {} func (x *LogoutResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4936,7 +4846,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. 
func (*LogoutResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{70} + return file_daemon_proto_rawDescGZIP(), []int{68} } type GetFeaturesRequest struct { @@ -4947,7 +4857,7 @@ type GetFeaturesRequest struct { func (x *GetFeaturesRequest) Reset() { *x = GetFeaturesRequest{} - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4959,7 +4869,7 @@ func (x *GetFeaturesRequest) String() string { func (*GetFeaturesRequest) ProtoMessage() {} func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[69] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4972,7 +4882,7 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead. func (*GetFeaturesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{71} + return file_daemon_proto_rawDescGZIP(), []int{69} } type GetFeaturesResponse struct { @@ -4986,7 +4896,7 @@ type GetFeaturesResponse struct { func (x *GetFeaturesResponse) Reset() { *x = GetFeaturesResponse{} - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4998,7 +4908,7 @@ func (x *GetFeaturesResponse) String() string { func (*GetFeaturesResponse) ProtoMessage() {} func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[70] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5011,7 +4921,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor 
instead. func (*GetFeaturesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{72} + return file_daemon_proto_rawDescGZIP(), []int{70} } func (x *GetFeaturesResponse) GetDisableProfiles() bool { @@ -5043,7 +4953,7 @@ type TriggerUpdateRequest struct { func (x *TriggerUpdateRequest) Reset() { *x = TriggerUpdateRequest{} - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[71] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5055,7 +4965,7 @@ func (x *TriggerUpdateRequest) String() string { func (*TriggerUpdateRequest) ProtoMessage() {} func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[71] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5068,7 +4978,7 @@ func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TriggerUpdateRequest.ProtoReflect.Descriptor instead. 
func (*TriggerUpdateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{73} + return file_daemon_proto_rawDescGZIP(), []int{71} } type TriggerUpdateResponse struct { @@ -5081,7 +4991,7 @@ type TriggerUpdateResponse struct { func (x *TriggerUpdateResponse) Reset() { *x = TriggerUpdateResponse{} - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[72] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5093,7 +5003,7 @@ func (x *TriggerUpdateResponse) String() string { func (*TriggerUpdateResponse) ProtoMessage() {} func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[72] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5106,7 +5016,7 @@ func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TriggerUpdateResponse.ProtoReflect.Descriptor instead. 
func (*TriggerUpdateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{74} + return file_daemon_proto_rawDescGZIP(), []int{72} } func (x *TriggerUpdateResponse) GetSuccess() bool { @@ -5134,7 +5044,7 @@ type GetPeerSSHHostKeyRequest struct { func (x *GetPeerSSHHostKeyRequest) Reset() { *x = GetPeerSSHHostKeyRequest{} - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[73] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5146,7 +5056,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string { func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[73] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5159,7 +5069,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. 
func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{75} + return file_daemon_proto_rawDescGZIP(), []int{73} } func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { @@ -5186,7 +5096,7 @@ type GetPeerSSHHostKeyResponse struct { func (x *GetPeerSSHHostKeyResponse) Reset() { *x = GetPeerSSHHostKeyResponse{} - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[74] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5198,7 +5108,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string { func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[74] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5211,7 +5121,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. 
func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{76} + return file_daemon_proto_rawDescGZIP(), []int{74} } func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { @@ -5253,7 +5163,7 @@ type RequestJWTAuthRequest struct { func (x *RequestJWTAuthRequest) Reset() { *x = RequestJWTAuthRequest{} - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5265,7 +5175,7 @@ func (x *RequestJWTAuthRequest) String() string { func (*RequestJWTAuthRequest) ProtoMessage() {} func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[75] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5278,7 +5188,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. 
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{77} + return file_daemon_proto_rawDescGZIP(), []int{75} } func (x *RequestJWTAuthRequest) GetHint() string { @@ -5311,7 +5221,7 @@ type RequestJWTAuthResponse struct { func (x *RequestJWTAuthResponse) Reset() { *x = RequestJWTAuthResponse{} - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5323,7 +5233,7 @@ func (x *RequestJWTAuthResponse) String() string { func (*RequestJWTAuthResponse) ProtoMessage() {} func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[76] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5336,7 +5246,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. 
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{78} + return file_daemon_proto_rawDescGZIP(), []int{76} } func (x *RequestJWTAuthResponse) GetVerificationURI() string { @@ -5401,7 +5311,7 @@ type WaitJWTTokenRequest struct { func (x *WaitJWTTokenRequest) Reset() { *x = WaitJWTTokenRequest{} - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5413,7 +5323,7 @@ func (x *WaitJWTTokenRequest) String() string { func (*WaitJWTTokenRequest) ProtoMessage() {} func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[77] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5426,7 +5336,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. 
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{79} + return file_daemon_proto_rawDescGZIP(), []int{77} } func (x *WaitJWTTokenRequest) GetDeviceCode() string { @@ -5458,7 +5368,7 @@ type WaitJWTTokenResponse struct { func (x *WaitJWTTokenResponse) Reset() { *x = WaitJWTTokenResponse{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5470,7 +5380,7 @@ func (x *WaitJWTTokenResponse) String() string { func (*WaitJWTTokenResponse) ProtoMessage() {} func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5483,7 +5393,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. 
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{80} + return file_daemon_proto_rawDescGZIP(), []int{78} } func (x *WaitJWTTokenResponse) GetToken() string { @@ -5516,7 +5426,7 @@ type StartCPUProfileRequest struct { func (x *StartCPUProfileRequest) Reset() { *x = StartCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[79] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5528,7 +5438,7 @@ func (x *StartCPUProfileRequest) String() string { func (*StartCPUProfileRequest) ProtoMessage() {} func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[79] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5541,7 +5451,7 @@ func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead. 
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{81} + return file_daemon_proto_rawDescGZIP(), []int{79} } // StartCPUProfileResponse confirms CPU profiling has started @@ -5553,7 +5463,7 @@ type StartCPUProfileResponse struct { func (x *StartCPUProfileResponse) Reset() { *x = StartCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5565,7 +5475,7 @@ func (x *StartCPUProfileResponse) String() string { func (*StartCPUProfileResponse) ProtoMessage() {} func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[80] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5578,7 +5488,7 @@ func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead. 
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{82} + return file_daemon_proto_rawDescGZIP(), []int{80} } // StopCPUProfileRequest for stopping CPU profiling @@ -5590,7 +5500,7 @@ type StopCPUProfileRequest struct { func (x *StopCPUProfileRequest) Reset() { *x = StopCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[81] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5602,7 +5512,7 @@ func (x *StopCPUProfileRequest) String() string { func (*StopCPUProfileRequest) ProtoMessage() {} func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[81] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5615,7 +5525,7 @@ func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead. 
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{83} + return file_daemon_proto_rawDescGZIP(), []int{81} } // StopCPUProfileResponse confirms CPU profiling has stopped @@ -5627,7 +5537,7 @@ type StopCPUProfileResponse struct { func (x *StopCPUProfileResponse) Reset() { *x = StopCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5639,7 +5549,7 @@ func (x *StopCPUProfileResponse) String() string { func (*StopCPUProfileResponse) ProtoMessage() {} func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5652,7 +5562,7 @@ func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead. 
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{84} + return file_daemon_proto_rawDescGZIP(), []int{82} } type InstallerResultRequest struct { @@ -5663,7 +5573,7 @@ type InstallerResultRequest struct { func (x *InstallerResultRequest) Reset() { *x = InstallerResultRequest{} - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5675,7 +5585,7 @@ func (x *InstallerResultRequest) String() string { func (*InstallerResultRequest) ProtoMessage() {} func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[83] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5688,7 +5598,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. 
func (*InstallerResultRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{85} + return file_daemon_proto_rawDescGZIP(), []int{83} } type InstallerResultResponse struct { @@ -5701,7 +5611,7 @@ type InstallerResultResponse struct { func (x *InstallerResultResponse) Reset() { *x = InstallerResultResponse{} - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5713,7 +5623,7 @@ func (x *InstallerResultResponse) String() string { func (*InstallerResultResponse) ProtoMessage() {} func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[84] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5726,7 +5636,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. 
func (*InstallerResultResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{86} + return file_daemon_proto_rawDescGZIP(), []int{84} } func (x *InstallerResultResponse) GetSuccess() bool { @@ -5759,7 +5669,7 @@ type ExposeServiceRequest struct { func (x *ExposeServiceRequest) Reset() { *x = ExposeServiceRequest{} - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[85] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5771,7 +5681,7 @@ func (x *ExposeServiceRequest) String() string { func (*ExposeServiceRequest) ProtoMessage() {} func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[85] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5784,7 +5694,7 @@ func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead. 
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{87} + return file_daemon_proto_rawDescGZIP(), []int{85} } func (x *ExposeServiceRequest) GetPort() uint32 { @@ -5855,7 +5765,7 @@ type ExposeServiceEvent struct { func (x *ExposeServiceEvent) Reset() { *x = ExposeServiceEvent{} - mi := &file_daemon_proto_msgTypes[88] + mi := &file_daemon_proto_msgTypes[86] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5867,7 +5777,7 @@ func (x *ExposeServiceEvent) String() string { func (*ExposeServiceEvent) ProtoMessage() {} func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[88] + mi := &file_daemon_proto_msgTypes[86] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5880,7 +5790,7 @@ func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead. 
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{88} + return file_daemon_proto_rawDescGZIP(), []int{86} } func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event { @@ -5921,7 +5831,7 @@ type ExposeServiceReady struct { func (x *ExposeServiceReady) Reset() { *x = ExposeServiceReady{} - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[87] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5933,7 +5843,7 @@ func (x *ExposeServiceReady) String() string { func (*ExposeServiceReady) ProtoMessage() {} func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[87] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5946,7 +5856,7 @@ func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead. 
func (*ExposeServiceReady) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{89} + return file_daemon_proto_rawDescGZIP(), []int{87} } func (x *ExposeServiceReady) GetServiceName() string { @@ -5977,6 +5887,288 @@ func (x *ExposeServiceReady) GetPortAutoAssigned() bool { return false } +type StartCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + TextOutput bool `protobuf:"varint,1,opt,name=text_output,json=textOutput,proto3" json:"text_output,omitempty"` + SnapLen uint32 `protobuf:"varint,2,opt,name=snap_len,json=snapLen,proto3" json:"snap_len,omitempty"` + Duration *durationpb.Duration `protobuf:"bytes,3,opt,name=duration,proto3" json:"duration,omitempty"` + FilterExpr string `protobuf:"bytes,4,opt,name=filter_expr,json=filterExpr,proto3" json:"filter_expr,omitempty"` + Verbose bool `protobuf:"varint,5,opt,name=verbose,proto3" json:"verbose,omitempty"` + Ascii bool `protobuf:"varint,6,opt,name=ascii,proto3" json:"ascii,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCaptureRequest) Reset() { + *x = StartCaptureRequest{} + mi := &file_daemon_proto_msgTypes[88] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCaptureRequest) ProtoMessage() {} + +func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[88] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead. 
+func (*StartCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{88} +} + +func (x *StartCaptureRequest) GetTextOutput() bool { + if x != nil { + return x.TextOutput + } + return false +} + +func (x *StartCaptureRequest) GetSnapLen() uint32 { + if x != nil { + return x.SnapLen + } + return 0 +} + +func (x *StartCaptureRequest) GetDuration() *durationpb.Duration { + if x != nil { + return x.Duration + } + return nil +} + +func (x *StartCaptureRequest) GetFilterExpr() string { + if x != nil { + return x.FilterExpr + } + return "" +} + +func (x *StartCaptureRequest) GetVerbose() bool { + if x != nil { + return x.Verbose + } + return false +} + +func (x *StartCaptureRequest) GetAscii() bool { + if x != nil { + return x.Ascii + } + return false +} + +type CapturePacket struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CapturePacket) Reset() { + *x = CapturePacket{} + mi := &file_daemon_proto_msgTypes[89] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CapturePacket) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CapturePacket) ProtoMessage() {} + +func (x *CapturePacket) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[89] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead. 
+func (*CapturePacket) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{89} +} + +func (x *CapturePacket) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type StartBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. + Timeout *durationpb.Duration `protobuf:"bytes,1,opt,name=timeout,proto3" json:"timeout,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureRequest) Reset() { + *x = StartBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[90] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureRequest) ProtoMessage() {} + +func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[90] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead. 
+func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{90} +} + +func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration { + if x != nil { + return x.Timeout + } + return nil +} + +type StartBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureResponse) Reset() { + *x = StartBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[91] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureResponse) ProtoMessage() {} + +func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[91] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead. 
+func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{91} +} + +type StopBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureRequest) Reset() { + *x = StopBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[92] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureRequest) ProtoMessage() {} + +func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[92] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead. 
+func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{92} +} + +type StopBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureResponse) Reset() { + *x = StopBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[93] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureResponse) ProtoMessage() {} + +func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[93] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead. 
+func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{93} +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5987,7 +6179,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[91] + mi := &file_daemon_proto_msgTypes[95] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5999,7 +6191,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[91] + mi := &file_daemon_proto_msgTypes[95] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6012,7 +6204,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. 
func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30, 0} + return file_daemon_proto_rawDescGZIP(), []int{28, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -6034,15 +6226,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\x7f\n" + - "\x12OSLifecycleRequest\x128\n" + - "\x04type\x18\x01 \x01(\x0e2$.daemon.OSLifecycleRequest.CycleTypeR\x04type\"/\n" + - "\tCycleType\x12\v\n" + - "\aUNKNOWN\x10\x00\x12\t\n" + - "\x05SLEEP\x10\x01\x12\n" + - "\n" + - "\x06WAKEUP\x10\x02\"\x15\n" + - "\x13OSLifecycleResponse\"\xb6\x12\n" + + "\fEmptyRequest\"\xef\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -6086,7 +6270,8 @@ const file_daemon_proto_rawDesc = "" + "\x1cenableSSHLocalPortForwarding\x18$ \x01(\bH\x17R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + "\x1denableSSHRemotePortForwarding\x18% \x01(\bH\x18R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + "\x0edisableSSHAuth\x18& \x01(\bH\x19R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + - "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + + "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01\x12&\n" + + "\fdisable_ipv6\x18( \x01(\bH\x1bR\vdisableIpv6\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -6113,7 +6298,8 @@ const file_daemon_proto_rawDesc = "" + "\x1d_enableSSHLocalPortForwardingB \n" + "\x1e_enableSSHRemotePortForwardingB\x11\n" + "\x0f_disableSSHAuthB\x11\n" + - "\x0f_sshJWTCacheTTL\"\xb5\x01\n" + + "\x0f_sshJWTCacheTTLB\x0f\n" + + "\r_disable_ipv6\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + 
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -6146,7 +6332,7 @@ const file_daemon_proto_rawDesc = "" + "\fDownResponse\"P\n" + "\x10GetConfigRequest\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername\"\xdb\b\n" + + "\busername\x18\x02 \x01(\tR\busername\"\xfe\b\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -6177,7 +6363,8 @@ const file_daemon_proto_rawDesc = "" + "\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" + "\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" + "\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" + - "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\"\xfe\x05\n" + + "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12!\n" + + "\fdisable_ipv6\x18\x1b \x01(\bR\vdisableIpv6\"\x92\x06\n" + "\tPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" + @@ -6201,7 +6388,8 @@ const file_daemon_proto_rawDesc = "" + "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\x12\x1e\n" + "\n" + "sshHostKey\x18\x13 \x01(\fR\n" + - "sshHostKey\"\xf0\x01\n" + + "sshHostKey\x12\x12\n" + + "\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x84\x02\n" + "\x0eLocalPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" + @@ -6209,7 +6397,8 @@ const file_daemon_proto_rawDesc = "" + "\x04fqdn\x18\x04 \x01(\tR\x04fqdn\x12*\n" + "\x10rosenpassEnabled\x18\x05 \x01(\bR\x10rosenpassEnabled\x120\n" + "\x13rosenpassPermissive\x18\x06 \x01(\bR\x13rosenpassPermissive\x12\x1a\n" + - "\bnetworks\x18\a \x03(\tR\bnetworks\"S\n" + + "\bnetworks\x18\a \x03(\tR\bnetworks\x12\x12\n" + + "\x04ipv6\x18\b \x01(\tR\x04ipv6\"S\n" + "\vSignalState\x12\x10\n" + "\x03URL\x18\x01 
\x01(\tR\x03URL\x12\x1c\n" + "\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" + @@ -6390,7 +6579,7 @@ const file_daemon_proto_rawDesc = "" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + "\t_username\"\x17\n" + - "\x15SwitchProfileResponse\"\xdf\x10\n" + + "\x15SwitchProfileResponse\"\x98\x11\n" + "\x10SetConfigRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + @@ -6429,7 +6618,8 @@ const file_daemon_proto_rawDesc = "" + "\x1cenableSSHLocalPortForwarding\x18\x1f \x01(\bH\x14R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + "\x1denableSSHRemotePortForwarding\x18 \x01(\bH\x15R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + "\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + - "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + + "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01\x12&\n" + + "\fdisable_ipv6\x18# \x01(\bH\x18R\vdisableIpv6\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -6453,7 +6643,8 @@ const file_daemon_proto_rawDesc = "" + "\x1d_enableSSHLocalPortForwardingB \n" + "\x1e_enableSSHRemotePortForwardingB\x11\n" + "\x0f_disableSSHAuthB\x11\n" + - "\x0f_sshJWTCacheTTL\"\x13\n" + + "\x0f_sshJWTCacheTTLB\x0f\n" + + "\r_disable_ipv6\"\x13\n" + "\x11SetConfigResponse\"Q\n" + "\x11AddProfileRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + @@ -6548,7 +6739,23 @@ const file_daemon_proto_rawDesc = "" + "\vservice_url\x18\x02 \x01(\tR\n" + "serviceUrl\x12\x16\n" + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + - "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned\"\xd9\x01\n" + + "\x13StartCaptureRequest\x12\x1f\n" + + "\vtext_output\x18\x01 \x01(\bR\n" + + "textOutput\x12\x19\n" + + 
"\bsnap_len\x18\x02 \x01(\rR\asnapLen\x125\n" + + "\bduration\x18\x03 \x01(\v2\x19.google.protobuf.DurationR\bduration\x12\x1f\n" + + "\vfilter_expr\x18\x04 \x01(\tR\n" + + "filterExpr\x12\x18\n" + + "\averbose\x18\x05 \x01(\bR\averbose\x12\x14\n" + + "\x05ascii\x18\x06 \x01(\bR\x05ascii\"#\n" + + "\rCapturePacket\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\"P\n" + + "\x19StartBundleCaptureRequest\x123\n" + + "\atimeout\x18\x01 \x01(\v2\x19.google.protobuf.DurationR\atimeout\"\x1c\n" + + "\x1aStartBundleCaptureResponse\"\x1a\n" + + "\x18StopBundleCaptureRequest\"\x1b\n" + + "\x19StopBundleCaptureResponse*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6566,7 +6773,7 @@ const file_daemon_proto_rawDesc = "" + "\n" + "EXPOSE_UDP\x10\x03\x12\x0e\n" + "\n" + - "EXPOSE_TLS\x10\x042\xfc\x15\n" + + "EXPOSE_TLS\x10\x042\xaf\x17\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6587,7 +6794,10 @@ const file_daemon_proto_rawDesc = "" + "CleanState\x12\x19.daemon.CleanStateRequest\x1a\x1a.daemon.CleanStateResponse\"\x00\x12H\n" + "\vDeleteState\x12\x1a.daemon.DeleteStateRequest\x1a\x1b.daemon.DeleteStateResponse\"\x00\x12u\n" + "\x1aSetSyncResponsePersistence\x12).daemon.SetSyncResponsePersistenceRequest\x1a*.daemon.SetSyncResponsePersistenceResponse\"\x00\x12H\n" + - "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" + + "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12F\n" + + "\fStartCapture\x12\x1b.daemon.StartCaptureRequest\x1a\x15.daemon.CapturePacket\"\x000\x01\x12]\n" + + "\x12StartBundleCapture\x12!.daemon.StartBundleCaptureRequest\x1a\".daemon.StartBundleCaptureResponse\"\x00\x12Z\n" + + "\x11StopBundleCapture\x12 
.daemon.StopBundleCaptureRequest\x1a!.daemon.StopBundleCaptureResponse\"\x00\x12D\n" + "\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" + "\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12N\n" + "\rSwitchProfile\x12\x1c.daemon.SwitchProfileRequest\x1a\x1d.daemon.SwitchProfileResponse\"\x00\x12B\n" + @@ -6604,8 +6814,7 @@ const file_daemon_proto_rawDesc = "" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" + "\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" + - "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" + - "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + + "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12W\n" + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" + "\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3" @@ -6621,226 +6830,234 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 93) +var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 97) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (ExposeProtocol)(0), // 1: daemon.ExposeProtocol - (OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType - (SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity - 
(SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category - (*EmptyRequest)(nil), // 5: daemon.EmptyRequest - (*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest - (*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse - (*LoginRequest)(nil), // 8: daemon.LoginRequest - (*LoginResponse)(nil), // 9: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 12: daemon.UpRequest - (*UpResponse)(nil), // 13: daemon.UpResponse - (*StatusRequest)(nil), // 14: daemon.StatusRequest - (*StatusResponse)(nil), // 15: daemon.StatusResponse - (*DownRequest)(nil), // 16: daemon.DownRequest - (*DownResponse)(nil), // 17: daemon.DownResponse - (*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse - (*PeerState)(nil), // 20: daemon.PeerState - (*LocalPeerState)(nil), // 21: daemon.LocalPeerState - (*SignalState)(nil), // 22: daemon.SignalState - (*ManagementState)(nil), // 23: daemon.ManagementState - (*RelayState)(nil), // 24: daemon.RelayState - (*NSGroupState)(nil), // 25: daemon.NSGroupState - (*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo - (*SSHServerState)(nil), // 27: daemon.SSHServerState - (*FullStatus)(nil), // 28: daemon.FullStatus - (*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse - (*IPList)(nil), // 33: daemon.IPList - (*Network)(nil), // 34: daemon.Network - (*PortInfo)(nil), // 35: daemon.PortInfo - (*ForwardingRule)(nil), // 36: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 39: 
daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse - (*State)(nil), // 44: daemon.State - (*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 53: daemon.TCPFlags - (*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest - (*TraceStage)(nil), // 55: daemon.TraceStage - (*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest - (*SystemEvent)(nil), // 58: daemon.SystemEvent - (*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse - (*Profile)(nil), // 
71: daemon.Profile - (*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 74: daemon.LogoutRequest - (*LogoutResponse)(nil), // 75: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse - (*TriggerUpdateRequest)(nil), // 78: daemon.TriggerUpdateRequest - (*TriggerUpdateResponse)(nil), // 79: daemon.TriggerUpdateResponse - (*GetPeerSSHHostKeyRequest)(nil), // 80: daemon.GetPeerSSHHostKeyRequest - (*GetPeerSSHHostKeyResponse)(nil), // 81: daemon.GetPeerSSHHostKeyResponse - (*RequestJWTAuthRequest)(nil), // 82: daemon.RequestJWTAuthRequest - (*RequestJWTAuthResponse)(nil), // 83: daemon.RequestJWTAuthResponse - (*WaitJWTTokenRequest)(nil), // 84: daemon.WaitJWTTokenRequest - (*WaitJWTTokenResponse)(nil), // 85: daemon.WaitJWTTokenResponse - (*StartCPUProfileRequest)(nil), // 86: daemon.StartCPUProfileRequest - (*StartCPUProfileResponse)(nil), // 87: daemon.StartCPUProfileResponse - (*StopCPUProfileRequest)(nil), // 88: daemon.StopCPUProfileRequest - (*StopCPUProfileResponse)(nil), // 89: daemon.StopCPUProfileResponse - (*InstallerResultRequest)(nil), // 90: daemon.InstallerResultRequest - (*InstallerResultResponse)(nil), // 91: daemon.InstallerResultResponse - (*ExposeServiceRequest)(nil), // 92: daemon.ExposeServiceRequest - (*ExposeServiceEvent)(nil), // 93: daemon.ExposeServiceEvent - (*ExposeServiceReady)(nil), // 94: daemon.ExposeServiceReady - nil, // 95: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 96: daemon.PortInfo.Range - nil, // 97: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 98: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 99: google.protobuf.Timestamp + (SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity + (SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category + 
(*EmptyRequest)(nil), // 4: daemon.EmptyRequest + (*LoginRequest)(nil), // 5: daemon.LoginRequest + (*LoginResponse)(nil), // 6: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 7: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 8: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 9: daemon.UpRequest + (*UpResponse)(nil), // 10: daemon.UpResponse + (*StatusRequest)(nil), // 11: daemon.StatusRequest + (*StatusResponse)(nil), // 12: daemon.StatusResponse + (*DownRequest)(nil), // 13: daemon.DownRequest + (*DownResponse)(nil), // 14: daemon.DownResponse + (*GetConfigRequest)(nil), // 15: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 16: daemon.GetConfigResponse + (*PeerState)(nil), // 17: daemon.PeerState + (*LocalPeerState)(nil), // 18: daemon.LocalPeerState + (*SignalState)(nil), // 19: daemon.SignalState + (*ManagementState)(nil), // 20: daemon.ManagementState + (*RelayState)(nil), // 21: daemon.RelayState + (*NSGroupState)(nil), // 22: daemon.NSGroupState + (*SSHSessionInfo)(nil), // 23: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 24: daemon.SSHServerState + (*FullStatus)(nil), // 25: daemon.FullStatus + (*ListNetworksRequest)(nil), // 26: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 27: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 28: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 29: daemon.SelectNetworksResponse + (*IPList)(nil), // 30: daemon.IPList + (*Network)(nil), // 31: daemon.Network + (*PortInfo)(nil), // 32: daemon.PortInfo + (*ForwardingRule)(nil), // 33: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 34: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 35: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 36: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 37: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 38: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 39: 
daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 40: daemon.SetLogLevelResponse + (*State)(nil), // 41: daemon.State + (*ListStatesRequest)(nil), // 42: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 43: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 44: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 45: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 46: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 47: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 48: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 49: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 50: daemon.TCPFlags + (*TracePacketRequest)(nil), // 51: daemon.TracePacketRequest + (*TraceStage)(nil), // 52: daemon.TraceStage + (*TracePacketResponse)(nil), // 53: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 54: daemon.SubscribeRequest + (*SystemEvent)(nil), // 55: daemon.SystemEvent + (*GetEventsRequest)(nil), // 56: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 57: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 58: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 59: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 60: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 61: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 62: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 63: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 64: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 65: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 66: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 67: daemon.ListProfilesResponse + (*Profile)(nil), // 68: daemon.Profile + (*GetActiveProfileRequest)(nil), // 69: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 70: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), 
// 71: daemon.LogoutRequest + (*LogoutResponse)(nil), // 72: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 73: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 74: daemon.GetFeaturesResponse + (*TriggerUpdateRequest)(nil), // 75: daemon.TriggerUpdateRequest + (*TriggerUpdateResponse)(nil), // 76: daemon.TriggerUpdateResponse + (*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse + (*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest + (*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse + (*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest + (*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse + (*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse + (*ExposeServiceRequest)(nil), // 89: daemon.ExposeServiceRequest + (*ExposeServiceEvent)(nil), // 90: daemon.ExposeServiceEvent + (*ExposeServiceReady)(nil), // 91: daemon.ExposeServiceReady + (*StartCaptureRequest)(nil), // 92: daemon.StartCaptureRequest + (*CapturePacket)(nil), // 93: daemon.CapturePacket + (*StartBundleCaptureRequest)(nil), // 94: daemon.StartBundleCaptureRequest + (*StartBundleCaptureResponse)(nil), // 95: daemon.StartBundleCaptureResponse + (*StopBundleCaptureRequest)(nil), // 96: daemon.StopBundleCaptureRequest + (*StopBundleCaptureResponse)(nil), // 97: daemon.StopBundleCaptureResponse + nil, // 98: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 99: daemon.PortInfo.Range + nil, // 100: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), 
// 101: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 102: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 98, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 99, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 99, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 98, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo - 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState - 24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState - 25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState - 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 95, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 96, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 44, // 23: 
daemon.ListStatesResponse.states:type_name -> daemon.State - 53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 99, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 97, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 98, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol - 94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady - 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 40, // 47: daemon.DaemonService.GetLogLevel:input_type -> 
daemon.GetLogLevelRequest - 42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 78, // 64: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest - 80, // 65: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 82, // 66: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 84, // 67: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 86, // 68: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest - 88, // 69: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest - 6, // 70: 
daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 90, // 71: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 92, // 72: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest - 9, // 73: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 11, // 74: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 13, // 75: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 15, // 76: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 17, // 77: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 19, // 78: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 30, // 79: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 32, // 80: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 32, // 81: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 37, // 82: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 39, // 83: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 41, // 84: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 43, // 85: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 46, // 86: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 48, // 87: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 50, // 88: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 52, // 89: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 56, // 90: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 58, // 91: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 60, // 92: daemon.DaemonService.GetEvents:output_type -> 
daemon.GetEventsResponse - 62, // 93: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 64, // 94: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 66, // 95: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 68, // 96: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 70, // 97: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 73, // 98: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 75, // 99: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 77, // 100: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 79, // 101: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse - 81, // 102: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 83, // 103: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 85, // 104: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 87, // 105: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse - 89, // 106: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse - 7, // 107: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 91, // 108: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 93, // 109: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent - 73, // [73:110] is the sub-list for method output_type - 36, // [36:73] is the sub-list for method input_type - 36, // [36:36] is the sub-list for extension type_name - 36, // [36:36] is the sub-list for extension extendee - 0, // [0:36] is the sub-list for field type_name + 101, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 25, // 1: 
daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 102, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 102, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 101, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 19, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 18, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 17, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState + 21, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState + 22, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 98, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 99, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 41, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State + 50, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 3, // 26: daemon.SystemEvent.category:type_name -> 
daemon.SystemEvent.Category + 102, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 100, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 101, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol + 91, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady + 101, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration + 101, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration + 30, // 36: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 5, // 37: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 7, // 38: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 9, // 39: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 11, // 40: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 13, // 41: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 15, // 42: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 26, // 43: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 28, // 44: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 28, // 45: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 4, // 46: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 35, // 47: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 37, // 48: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 39, // 49: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 42, // 50: 
daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 44, // 51: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 46, // 52: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 48, // 53: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 51, // 54: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 92, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest + 94, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest + 96, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest + 54, // 58: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 56, // 59: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 58, // 60: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 60, // 61: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 62, // 62: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 64, // 63: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 66, // 64: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 69, // 65: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 71, // 66: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 73, // 67: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 75, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest + 77, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 79, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 81, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 83, // 72: 
daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 85, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 87, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 89, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest + 6, // 76: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 8, // 77: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 10, // 78: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 12, // 79: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 14, // 80: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 16, // 81: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 27, // 82: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 29, // 83: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 29, // 84: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 34, // 85: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 36, // 86: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 38, // 87: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 40, // 88: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 43, // 89: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 45, // 90: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 47, // 91: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 49, // 92: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 53, // 93: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 93, // 94: daemon.DaemonService.StartCapture:output_type -> 
daemon.CapturePacket + 95, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse + 97, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse + 55, // 97: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 57, // 98: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 59, // 99: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 61, // 100: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 63, // 101: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 65, // 102: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 67, // 103: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 70, // 104: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 72, // 105: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 74, // 106: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 76, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse + 78, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 84, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 86, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 88, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 90, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent + 76, // [76:115] is the sub-list for method output_type + 37, // [37:76] is the sub-list for method input_type + 37, // [37:37] is 
the sub-list for extension type_name + 37, // [37:37] is the sub-list for extension extendee + 0, // [0:37] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -6848,20 +7065,20 @@ func file_daemon_proto_init() { if File_daemon_proto != nil { return } - file_daemon_proto_msgTypes[3].OneofWrappers = []any{} + file_daemon_proto_msgTypes[1].OneofWrappers = []any{} + file_daemon_proto_msgTypes[5].OneofWrappers = []any{} file_daemon_proto_msgTypes[7].OneofWrappers = []any{} - file_daemon_proto_msgTypes[9].OneofWrappers = []any{} - file_daemon_proto_msgTypes[30].OneofWrappers = []any{ + file_daemon_proto_msgTypes[28].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } - file_daemon_proto_msgTypes[49].OneofWrappers = []any{} - file_daemon_proto_msgTypes[50].OneofWrappers = []any{} + file_daemon_proto_msgTypes[47].OneofWrappers = []any{} + file_daemon_proto_msgTypes[48].OneofWrappers = []any{} + file_daemon_proto_msgTypes[54].OneofWrappers = []any{} file_daemon_proto_msgTypes[56].OneofWrappers = []any{} - file_daemon_proto_msgTypes[58].OneofWrappers = []any{} - file_daemon_proto_msgTypes[69].OneofWrappers = []any{} - file_daemon_proto_msgTypes[77].OneofWrappers = []any{} - file_daemon_proto_msgTypes[88].OneofWrappers = []any{ + file_daemon_proto_msgTypes[67].OneofWrappers = []any{} + file_daemon_proto_msgTypes[75].OneofWrappers = []any{} + file_daemon_proto_msgTypes[86].OneofWrappers = []any{ (*ExposeServiceEvent_Ready)(nil), } type x struct{} @@ -6869,8 +7086,8 @@ func file_daemon_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), - NumEnums: 5, - NumMessages: 93, + NumEnums: 4, + NumMessages: 97, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 19976660c..dedff43e2 100644 --- a/client/proto/daemon.proto +++ 
b/client/proto/daemon.proto @@ -64,6 +64,17 @@ service DaemonService { rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {} + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {} + + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {} + + // StopBundleCapture stops the running bundle capture. Idempotent. + rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {} + rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {} rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {} @@ -104,8 +115,6 @@ service DaemonService { // StopCPUProfile stops CPU profiling in the daemon rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {} - rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} - rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} // ExposeService exposes a local port via the NetBird reverse proxy @@ -114,20 +123,6 @@ service DaemonService { -message OSLifecycleRequest { - // avoid collision with loglevel enum - enum CycleType { - UNKNOWN = 0; - SLEEP = 1; - WAKEUP = 2; - } - - CycleType type = 1; -} - -message OSLifecycleResponse {} - - message LoginRequest { // setupKey netbird setup key. 
string setupKey = 1; @@ -209,6 +204,7 @@ message LoginRequest { optional bool enableSSHRemotePortForwarding = 37; optional bool disableSSHAuth = 38; optional int32 sshJWTCacheTTL = 39; + optional bool disable_ipv6 = 40; } message LoginResponse { @@ -316,6 +312,8 @@ message GetConfigResponse { bool disableSSHAuth = 25; int32 sshJWTCacheTTL = 26; + + bool disable_ipv6 = 27; } // PeerState contains the latest state of a peer @@ -338,6 +336,7 @@ message PeerState { google.protobuf.Duration latency = 17; string relayAddress = 18; bytes sshHostKey = 19; + string ipv6 = 20; } // LocalPeerState contains the latest state of the local peer @@ -349,6 +348,7 @@ message LocalPeerState { bool rosenpassEnabled = 5; bool rosenpassPermissive = 6; repeated string networks = 7; + string ipv6 = 8; } // SignalState contains the latest state of a signal connection @@ -677,6 +677,7 @@ message SetConfigRequest { optional bool enableSSHRemotePortForwarding = 32; optional bool disableSSHAuth = 33; optional int32 sshJWTCacheTTL = 34; + optional bool disable_ipv6 = 35; } message SetConfigResponse{} @@ -848,3 +849,26 @@ message ExposeServiceReady { string domain = 3; bool port_auto_assigned = 4; } + +message StartCaptureRequest { + bool text_output = 1; + uint32 snap_len = 2; + google.protobuf.Duration duration = 3; + string filter_expr = 4; + bool verbose = 5; + bool ascii = 6; +} + +message CapturePacket { + bytes data = 1; +} + +message StartBundleCaptureRequest { + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. 
+ google.protobuf.Duration timeout = 1; +} + +message StartBundleCaptureResponse {} +message StopBundleCaptureRequest {} +message StopBundleCaptureResponse {} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index e5bd89597..66a8efcc3 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc v6.33.1 +// source: daemon.proto package proto @@ -11,8 +15,50 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + DaemonService_Login_FullMethodName = "/daemon.DaemonService/Login" + DaemonService_WaitSSOLogin_FullMethodName = "/daemon.DaemonService/WaitSSOLogin" + DaemonService_Up_FullMethodName = "/daemon.DaemonService/Up" + DaemonService_Status_FullMethodName = "/daemon.DaemonService/Status" + DaemonService_Down_FullMethodName = "/daemon.DaemonService/Down" + DaemonService_GetConfig_FullMethodName = "/daemon.DaemonService/GetConfig" + DaemonService_ListNetworks_FullMethodName = "/daemon.DaemonService/ListNetworks" + DaemonService_SelectNetworks_FullMethodName = "/daemon.DaemonService/SelectNetworks" + DaemonService_DeselectNetworks_FullMethodName = "/daemon.DaemonService/DeselectNetworks" + DaemonService_ForwardingRules_FullMethodName = "/daemon.DaemonService/ForwardingRules" + DaemonService_DebugBundle_FullMethodName = "/daemon.DaemonService/DebugBundle" + DaemonService_GetLogLevel_FullMethodName = "/daemon.DaemonService/GetLogLevel" + DaemonService_SetLogLevel_FullMethodName = "/daemon.DaemonService/SetLogLevel" + DaemonService_ListStates_FullMethodName = "/daemon.DaemonService/ListStates" + 
DaemonService_CleanState_FullMethodName = "/daemon.DaemonService/CleanState" + DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState" + DaemonService_SetSyncResponsePersistence_FullMethodName = "/daemon.DaemonService/SetSyncResponsePersistence" + DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket" + DaemonService_StartCapture_FullMethodName = "/daemon.DaemonService/StartCapture" + DaemonService_StartBundleCapture_FullMethodName = "/daemon.DaemonService/StartBundleCapture" + DaemonService_StopBundleCapture_FullMethodName = "/daemon.DaemonService/StopBundleCapture" + DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents" + DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents" + DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile" + DaemonService_SetConfig_FullMethodName = "/daemon.DaemonService/SetConfig" + DaemonService_AddProfile_FullMethodName = "/daemon.DaemonService/AddProfile" + DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile" + DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles" + DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile" + DaemonService_Logout_FullMethodName = "/daemon.DaemonService/Logout" + DaemonService_GetFeatures_FullMethodName = "/daemon.DaemonService/GetFeatures" + DaemonService_TriggerUpdate_FullMethodName = "/daemon.DaemonService/TriggerUpdate" + DaemonService_GetPeerSSHHostKey_FullMethodName = "/daemon.DaemonService/GetPeerSSHHostKey" + DaemonService_RequestJWTAuth_FullMethodName = "/daemon.DaemonService/RequestJWTAuth" + DaemonService_WaitJWTToken_FullMethodName = "/daemon.DaemonService/WaitJWTToken" + DaemonService_StartCPUProfile_FullMethodName = "/daemon.DaemonService/StartCPUProfile" + DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile" + 
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult" + DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService" +) // DaemonServiceClient is the client API for DaemonService service. // @@ -53,7 +99,15 @@ type DaemonServiceClient interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) - SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error) + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. 
+ StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) + SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) @@ -77,10 +131,9 @@ type DaemonServiceClient interface { StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) // StopCPUProfile stops CPU profiling in the daemon StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) - NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) // ExposeService exposes a local port via the NetBird reverse proxy - ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) + ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error) } type daemonServiceClient struct { @@ -92,8 +145,9 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient { } func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(LoginResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Login", in, out, opts...) 
+ err := c.cc.Invoke(ctx, DaemonService_Login_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -101,8 +155,9 @@ func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts } func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(WaitSSOLoginResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitSSOLogin", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_WaitSSOLogin_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -110,8 +165,9 @@ func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLogin } func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(UpResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Up", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Up_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -119,8 +175,9 @@ func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grp } func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StatusResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Status", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Status_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -128,8 +185,9 @@ func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opt } func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
out := new(DownResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Down", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Down_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -137,8 +195,9 @@ func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts .. } func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetConfigResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetConfig", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetConfig_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -146,8 +205,9 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques } func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListNetworksResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListNetworks_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -155,8 +215,9 @@ func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworks } func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SelectNetworks_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -164,8 +225,9 @@ func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetw } func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_DeselectNetworks_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -173,8 +235,9 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe } func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ForwardingRulesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ForwardingRules_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -182,8 +245,9 @@ func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequ } func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DebugBundleResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_DebugBundle_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -191,8 +255,9 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe } func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetLogLevelResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetLogLevel_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -200,8 +265,9 @@ func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRe } func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetLogLevelResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SetLogLevel_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -209,8 +275,9 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe } func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListStatesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListStates_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -218,8 +285,9 @@ func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequ } func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
out := new(CleanStateResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_CleanState_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -227,8 +295,9 @@ func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequ } func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DeleteStateResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_DeleteState_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -236,8 +305,9 @@ func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRe } func (c *daemonServiceClient) SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetSyncResponsePersistenceResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetSyncResponsePersistence", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SetSyncResponsePersistence_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -245,20 +315,22 @@ func (c *daemonServiceClient) SetSyncResponsePersistence(ctx context.Context, in } func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(TracePacketResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_TracePacket_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) { - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...) +func (c *daemonServiceClient) StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_StartCapture_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &daemonServiceSubscribeEventsClient{stream} + x := &grpc.GenericClientStream[StartCaptureRequest, CapturePacket]{ClientStream: stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -268,26 +340,52 @@ func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *Subscribe return x, nil } -type DaemonService_SubscribeEventsClient interface { - Recv() (*SystemEvent, error) - grpc.ClientStream -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_StartCaptureClient = grpc.ServerStreamingClient[CapturePacket] -type daemonServiceSubscribeEventsClient struct { - grpc.ClientStream -} - -func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) { - m := new(SystemEvent) - if err := x.ClientStream.RecvMsg(m); err != nil { +func (c *daemonServiceClient) StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(StartBundleCaptureResponse) + err := c.cc.Invoke(ctx, DaemonService_StartBundleCapture_FullMethodName, in, out, cOpts...) 
+ if err != nil { return nil, err } - return m, nil + return out, nil } +func (c *daemonServiceClient) StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(StopBundleCaptureResponse) + err := c.cc.Invoke(ctx, DaemonService_StopBundleCapture_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], DaemonService_SubscribeEvents_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[SubscribeRequest, SystemEvent]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_SubscribeEventsClient = grpc.ServerStreamingClient[SystemEvent] + func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetEventsResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetEvents_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -295,8 +393,9 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques } func (c *daemonServiceClient) SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SwitchProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SwitchProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SwitchProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -304,8 +403,9 @@ func (c *daemonServiceClient) SwitchProfile(ctx context.Context, in *SwitchProfi } func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetConfigResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetConfig", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_SetConfig_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -313,8 +413,9 @@ func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigReques } func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(AddProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/AddProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_AddProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -322,8 +423,9 @@ func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequ } func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
out := new(RemoveProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/RemoveProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_RemoveProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -331,8 +433,9 @@ func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfi } func (c *daemonServiceClient) ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListProfilesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListProfiles", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListProfiles_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -340,8 +443,9 @@ func (c *daemonServiceClient) ListProfiles(ctx context.Context, in *ListProfiles } func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetActiveProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetActiveProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetActiveProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -349,8 +453,9 @@ func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiv } func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(LogoutResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/Logout", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_Logout_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -358,8 +463,9 @@ func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opt } func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetFeaturesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetFeatures", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetFeatures_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -367,8 +473,9 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe } func (c *daemonServiceClient) TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(TriggerUpdateResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/TriggerUpdate", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_TriggerUpdate_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -376,8 +483,9 @@ func (c *daemonServiceClient) TriggerUpdate(ctx context.Context, in *TriggerUpda } func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetPeerSSHHostKeyResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetPeerSSHHostKey_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -385,8 +493,9 @@ func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeer } func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(RequestJWTAuthResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_RequestJWTAuth_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -394,8 +503,9 @@ func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWT } func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(WaitJWTTokenResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_WaitJWTToken_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -403,8 +513,9 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken } func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StartCPUProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_StartCPUProfile_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -412,17 +523,9 @@ func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUP } func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StopCPUProfileResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) { - out := new(OSLifecycleResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_StopCPUProfile_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -430,20 +533,22 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec } func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(InstallerResultResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetInstallerResult_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) { - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...) 
+func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[2], DaemonService_ExposeService_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &daemonServiceExposeServiceClient{stream} + x := &grpc.GenericClientStream[ExposeServiceRequest, ExposeServiceEvent]{ClientStream: stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -453,26 +558,12 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi return x, nil } -type DaemonService_ExposeServiceClient interface { - Recv() (*ExposeServiceEvent, error) - grpc.ClientStream -} - -type daemonServiceExposeServiceClient struct { - grpc.ClientStream -} - -func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) { - m := new(ExposeServiceEvent) - if err := x.ClientStream.RecvMsg(m); err != nil { - return nil, err - } - return m, nil -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent] // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer -// for forward compatibility +// for forward compatibility. type DaemonServiceServer interface { // Login uses setup key to prepare configuration for the daemon. 
Login(context.Context, *LoginRequest) (*LoginResponse, error) @@ -509,7 +600,15 @@ type DaemonServiceServer interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) - SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. 
+ StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) + SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) @@ -533,129 +632,138 @@ type DaemonServiceServer interface { StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) // StopCPUProfile stops CPU profiling in the daemon StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) - NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) // ExposeService exposes a local port via the NetBird reverse proxy - ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error + ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error mustEmbedUnimplementedDaemonServiceServer() } -// UnimplementedDaemonServiceServer must be embedded to have forward compatible implementations. -type UnimplementedDaemonServiceServer struct { -} +// UnimplementedDaemonServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. 
+type UnimplementedDaemonServiceServer struct{} func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") + return nil, status.Error(codes.Unimplemented, "method Login not implemented") } func (UnimplementedDaemonServiceServer) WaitSSOLogin(context.Context, *WaitSSOLoginRequest) (*WaitSSOLoginResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method WaitSSOLogin not implemented") + return nil, status.Error(codes.Unimplemented, "method WaitSSOLogin not implemented") } func (UnimplementedDaemonServiceServer) Up(context.Context, *UpRequest) (*UpResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Up not implemented") + return nil, status.Error(codes.Unimplemented, "method Up not implemented") } func (UnimplementedDaemonServiceServer) Status(context.Context, *StatusRequest) (*StatusResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Status not implemented") + return nil, status.Error(codes.Unimplemented, "method Status not implemented") } func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*DownResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Down not implemented") + return nil, status.Error(codes.Unimplemented, "method Down not implemented") } func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") + return nil, status.Error(codes.Unimplemented, "method GetConfig not implemented") } func (UnimplementedDaemonServiceServer) ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListNetworks not implemented") + return nil, status.Error(codes.Unimplemented, "method ListNetworks not implemented") } func 
(UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SelectNetworks not implemented") + return nil, status.Error(codes.Unimplemented, "method SelectNetworks not implemented") } func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented") + return nil, status.Error(codes.Unimplemented, "method DeselectNetworks not implemented") } func (UnimplementedDaemonServiceServer) ForwardingRules(context.Context, *EmptyRequest) (*ForwardingRulesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ForwardingRules not implemented") + return nil, status.Error(codes.Unimplemented, "method ForwardingRules not implemented") } func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented") + return nil, status.Error(codes.Unimplemented, "method DebugBundle not implemented") } func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetLogLevel not implemented") + return nil, status.Error(codes.Unimplemented, "method GetLogLevel not implemented") } func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented") + return nil, status.Error(codes.Unimplemented, "method SetLogLevel not implemented") } func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method 
ListStates not implemented") + return nil, status.Error(codes.Unimplemented, "method ListStates not implemented") } func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented") + return nil, status.Error(codes.Unimplemented, "method CleanState not implemented") } func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented") + return nil, status.Error(codes.Unimplemented, "method DeleteState not implemented") } func (UnimplementedDaemonServiceServer) SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SetSyncResponsePersistence not implemented") + return nil, status.Error(codes.Unimplemented, "method SetSyncResponsePersistence not implemented") } func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") + return nil, status.Error(codes.Unimplemented, "method TracePacket not implemented") } -func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error { - return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented") +func (UnimplementedDaemonServiceServer) StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error { + return status.Error(codes.Unimplemented, "method StartCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) { + return nil, status.Error(codes.Unimplemented, "method 
StartBundleCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) { + return nil, status.Error(codes.Unimplemented, "method StopBundleCapture not implemented") +} +func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error { + return status.Error(codes.Unimplemented, "method SubscribeEvents not implemented") } func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented") + return nil, status.Error(codes.Unimplemented, "method GetEvents not implemented") } func (UnimplementedDaemonServiceServer) SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SwitchProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method SwitchProfile not implemented") } func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SetConfig not implemented") + return nil, status.Error(codes.Unimplemented, "method SetConfig not implemented") } func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method AddProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method AddProfile not implemented") } func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RemoveProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method RemoveProfile not implemented") } func 
(UnimplementedDaemonServiceServer) ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListProfiles not implemented") + return nil, status.Error(codes.Unimplemented, "method ListProfiles not implemented") } func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetActiveProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method GetActiveProfile not implemented") } func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented") + return nil, status.Error(codes.Unimplemented, "method Logout not implemented") } func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") + return nil, status.Error(codes.Unimplemented, "method GetFeatures not implemented") } func (UnimplementedDaemonServiceServer) TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method TriggerUpdate not implemented") + return nil, status.Error(codes.Unimplemented, "method TriggerUpdate not implemented") } func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") + return nil, status.Error(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") } func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) { - return nil, 
status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented") + return nil, status.Error(codes.Unimplemented, "method RequestJWTAuth not implemented") } func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") + return nil, status.Error(codes.Unimplemented, "method WaitJWTToken not implemented") } func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented") + return nil, status.Error(codes.Unimplemented, "method StartCPUProfile not implemented") } func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented") -} -func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") + return nil, status.Error(codes.Unimplemented, "method StopCPUProfile not implemented") } func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") + return nil, status.Error(codes.Unimplemented, "method GetInstallerResult not implemented") } -func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error { - return status.Errorf(codes.Unimplemented, "method ExposeService not implemented") +func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error { + return 
status.Error(codes.Unimplemented, "method ExposeService not implemented") } func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} +func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to DaemonServiceServer will @@ -665,6 +773,13 @@ type UnsafeDaemonServiceServer interface { } func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) { + // If the following call panics, it indicates UnimplementedDaemonServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } s.RegisterService(&DaemonService_ServiceDesc, srv) } @@ -678,7 +793,7 @@ func _DaemonService_Login_Handler(srv interface{}, ctx context.Context, dec func } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Login", + FullMethod: DaemonService_Login_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest)) @@ -696,7 +811,7 @@ func _DaemonService_WaitSSOLogin_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/WaitSSOLogin", + FullMethod: DaemonService_WaitSSOLogin_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest)) @@ -714,7 +829,7 @@ func _DaemonService_Up_Handler(srv interface{}, ctx context.Context, dec func(in } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Up", + 
FullMethod: DaemonService_Up_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest)) @@ -732,7 +847,7 @@ func _DaemonService_Status_Handler(srv interface{}, ctx context.Context, dec fun } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Status", + FullMethod: DaemonService_Status_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest)) @@ -750,7 +865,7 @@ func _DaemonService_Down_Handler(srv interface{}, ctx context.Context, dec func( } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Down", + FullMethod: DaemonService_Down_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest)) @@ -768,7 +883,7 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetConfig", + FullMethod: DaemonService_GetConfig_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest)) @@ -786,7 +901,7 @@ func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListNetworks", + FullMethod: DaemonService_ListNetworks_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) @@ -804,7 +919,7 @@ func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SelectNetworks", + FullMethod: 
DaemonService_SelectNetworks_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -822,7 +937,7 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeselectNetworks", + FullMethod: DaemonService_DeselectNetworks_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -840,7 +955,7 @@ func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ForwardingRules", + FullMethod: DaemonService_ForwardingRules_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest)) @@ -858,7 +973,7 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DebugBundle", + FullMethod: DaemonService_DebugBundle_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest)) @@ -876,7 +991,7 @@ func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetLogLevel", + FullMethod: DaemonService_GetLogLevel_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest)) @@ -894,7 +1009,7 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := 
&grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SetLogLevel", + FullMethod: DaemonService_SetLogLevel_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest)) @@ -912,7 +1027,7 @@ func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListStates", + FullMethod: DaemonService_ListStates_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest)) @@ -930,7 +1045,7 @@ func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/CleanState", + FullMethod: DaemonService_CleanState_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest)) @@ -948,7 +1063,7 @@ func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeleteState", + FullMethod: DaemonService_DeleteState_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest)) @@ -966,7 +1081,7 @@ func _DaemonService_SetSyncResponsePersistence_Handler(srv interface{}, ctx cont } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SetSyncResponsePersistence", + FullMethod: DaemonService_SetSyncResponsePersistence_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetSyncResponsePersistence(ctx, req.(*SetSyncResponsePersistenceRequest)) @@ 
-984,7 +1099,7 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/TracePacket", + FullMethod: DaemonService_TracePacket_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest)) @@ -992,26 +1107,63 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_StartCapture_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(StartCaptureRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DaemonServiceServer).StartCapture(m, &grpc.GenericServerStream[StartCaptureRequest, CapturePacket]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. 
+type DaemonService_StartCaptureServer = grpc.ServerStreamingServer[CapturePacket] + +func _DaemonService_StartBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StartBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: DaemonService_StartBundleCapture_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, req.(*StartBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_StopBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StopBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: DaemonService_StopBundleCapture_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, req.(*StopBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(SubscribeRequest) if err := stream.RecvMsg(m); err != nil { return err } - return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream}) + return srv.(DaemonServiceServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeRequest, SystemEvent]{ServerStream: stream}) } -type DaemonService_SubscribeEventsServer interface { - Send(*SystemEvent) error - 
grpc.ServerStream -} - -type daemonServiceSubscribeEventsServer struct { - grpc.ServerStream -} - -func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error { - return x.ServerStream.SendMsg(m) -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_SubscribeEventsServer = grpc.ServerStreamingServer[SystemEvent] func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetEventsRequest) @@ -1023,7 +1175,7 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetEvents", + FullMethod: DaemonService_GetEvents_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest)) @@ -1041,7 +1193,7 @@ func _DaemonService_SwitchProfile_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SwitchProfile", + FullMethod: DaemonService_SwitchProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SwitchProfile(ctx, req.(*SwitchProfileRequest)) @@ -1059,7 +1211,7 @@ func _DaemonService_SetConfig_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SetConfig", + FullMethod: DaemonService_SetConfig_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetConfig(ctx, req.(*SetConfigRequest)) @@ -1077,7 +1229,7 @@ func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: 
srv, - FullMethod: "/daemon.DaemonService/AddProfile", + FullMethod: DaemonService_AddProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).AddProfile(ctx, req.(*AddProfileRequest)) @@ -1095,7 +1247,7 @@ func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/RemoveProfile", + FullMethod: DaemonService_RemoveProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).RemoveProfile(ctx, req.(*RemoveProfileRequest)) @@ -1113,7 +1265,7 @@ func _DaemonService_ListProfiles_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListProfiles", + FullMethod: DaemonService_ListProfiles_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListProfiles(ctx, req.(*ListProfilesRequest)) @@ -1131,7 +1283,7 @@ func _DaemonService_GetActiveProfile_Handler(srv interface{}, ctx context.Contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetActiveProfile", + FullMethod: DaemonService_GetActiveProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetActiveProfile(ctx, req.(*GetActiveProfileRequest)) @@ -1149,7 +1301,7 @@ func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec fun } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/Logout", + FullMethod: DaemonService_Logout_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Logout(ctx, req.(*LogoutRequest)) @@ -1167,7 +1319,7 @@ func _DaemonService_GetFeatures_Handler(srv 
interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetFeatures", + FullMethod: DaemonService_GetFeatures_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetFeatures(ctx, req.(*GetFeaturesRequest)) @@ -1185,7 +1337,7 @@ func _DaemonService_TriggerUpdate_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/TriggerUpdate", + FullMethod: DaemonService_TriggerUpdate_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).TriggerUpdate(ctx, req.(*TriggerUpdateRequest)) @@ -1203,7 +1355,7 @@ func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Conte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey", + FullMethod: DaemonService_GetPeerSSHHostKey_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest)) @@ -1221,7 +1373,7 @@ func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/RequestJWTAuth", + FullMethod: DaemonService_RequestJWTAuth_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest)) @@ -1239,7 +1391,7 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/WaitJWTToken", + FullMethod: DaemonService_WaitJWTToken_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return 
srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest)) @@ -1257,7 +1409,7 @@ func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/StartCPUProfile", + FullMethod: DaemonService_StartCPUProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest)) @@ -1275,7 +1427,7 @@ func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/StopCPUProfile", + FullMethod: DaemonService_StopCPUProfile_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest)) @@ -1283,24 +1435,6 @@ func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, return interceptor(ctx, in, info, handler) } -func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(OSLifecycleRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/daemon.DaemonService/NotifyOSLifecycle", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest)) - } - return interceptor(ctx, in, info, handler) -} - func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(InstallerResultRequest) if err := dec(in); err != nil 
{ @@ -1311,7 +1445,7 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/GetInstallerResult", + FullMethod: DaemonService_GetInstallerResult_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest)) @@ -1324,21 +1458,11 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr if err := stream.RecvMsg(m); err != nil { return err } - return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream}) + return srv.(DaemonServiceServer).ExposeService(m, &grpc.GenericServerStream[ExposeServiceRequest, ExposeServiceEvent]{ServerStream: stream}) } -type DaemonService_ExposeServiceServer interface { - Send(*ExposeServiceEvent) error - grpc.ServerStream -} - -type daemonServiceExposeServiceServer struct { - grpc.ServerStream -} - -func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error { - return x.ServerStream.SendMsg(m) -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent] // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. 
// It's only intended for direct use with grpc.RegisterService, @@ -1419,6 +1543,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "TracePacket", Handler: _DaemonService_TracePacket_Handler, }, + { + MethodName: "StartBundleCapture", + Handler: _DaemonService_StartBundleCapture_Handler, + }, + { + MethodName: "StopBundleCapture", + Handler: _DaemonService_StopBundleCapture_Handler, + }, { MethodName: "GetEvents", Handler: _DaemonService_GetEvents_Handler, @@ -1479,16 +1611,17 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "StopCPUProfile", Handler: _DaemonService_StopCPUProfile_Handler, }, - { - MethodName: "NotifyOSLifecycle", - Handler: _DaemonService_NotifyOSLifecycle_Handler, - }, { MethodName: "GetInstallerResult", Handler: _DaemonService_GetInstallerResult_Handler, }, }, Streams: []grpc.StreamDesc{ + { + StreamName: "StartCapture", + Handler: _DaemonService_StartCapture_Handler, + ServerStreams: true, + }, { StreamName: "SubscribeEvents", Handler: _DaemonService_SubscribeEvents_Handler, diff --git a/client/server/capture.go b/client/server/capture.go new file mode 100644 index 000000000..308c00338 --- /dev/null +++ b/client/server/capture.go @@ -0,0 +1,365 @@ +package server + +import ( + "context" + "io" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +const maxBundleCaptureDuration = 10 * time.Minute + +// bundleCapture holds the state of an in-progress capture destined for the +// debug bundle. 
The lifecycle is: +// +// StartBundleCapture → capture running, writing to temp file +// StopBundleCapture → capture stopped, temp file available +// DebugBundle → temp file included in zip, then cleaned up +type bundleCapture struct { + mu sync.Mutex + sess *capture.Session + file *os.File + engine *internal.Engine + cancel context.CancelFunc + stopped bool +} + +// stop halts the capture session and closes the pcap writer. Idempotent. +func (bc *bundleCapture) stop() { + bc.mu.Lock() + defer bc.mu.Unlock() + + if bc.stopped { + return + } + bc.stopped = true + + if bc.cancel != nil { + bc.cancel() + } + if bc.sess != nil { + bc.sess.Stop() + } +} + +// path returns the temp file path, or "" if no file exists. +func (bc *bundleCapture) path() string { + if bc.file == nil { + return "" + } + return bc.file.Name() +} + +// cleanup removes the temp file. +func (bc *bundleCapture) cleanup() { + if bc.file == nil { + return + } + name := bc.file.Name() + if err := bc.file.Close(); err != nil { + log.Debugf("close bundle capture file: %v", err) + } + if err := os.Remove(name); err != nil && !os.IsNotExist(err) { + log.Debugf("remove bundle capture file: %v", err) + } + bc.file = nil +} + +// StartCapture streams a pcap or text packet capture over gRPC. +// Gated by the --enable-capture service flag. 
+func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error { + if !s.captureEnabled { + return status.Error(codes.PermissionDenied, + "packet capture is disabled; reinstall or reconfigure the service with --enable-capture") + } + + if d := req.GetDuration(); d != nil && d.AsDuration() < 0 { + return status.Error(codes.InvalidArgument, "duration must not be negative") + } + + matcher, err := parseCaptureFilter(req) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) + } + + pr, pw := io.Pipe() + + opts := capture.Options{ + Matcher: matcher, + SnapLen: req.GetSnapLen(), + Verbose: req.GetVerbose(), + ASCII: req.GetAscii(), + } + if req.GetTextOutput() { + opts.TextOutput = pw + } else { + opts.Output = pw + } + + sess, err := capture.NewSession(opts) + if err != nil { + pw.Close() + return status.Errorf(codes.Internal, "create capture session: %v", err) + } + + engine, err := s.claimCapture(sess) + if err != nil { + sess.Stop() + pw.Close() + return err + } + + if err := engine.SetCapture(sess); err != nil { + s.releaseCapture(sess) + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "set capture: %v", err) + } + + // Send an empty initial message to signal that the capture was accepted. + // The client waits for this before printing the banner, so it must arrive + // before any packet data. 
+ if err := stream.Send(&proto.CapturePacket{}); err != nil { + s.clearCaptureIfOwner(sess, engine) + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "send initial message: %v", err) + } + + ctx := stream.Context() + if d := req.GetDuration(); d != nil { + if dur := d.AsDuration(); dur > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, dur) + defer cancel() + } + } + + go func() { + <-ctx.Done() + s.clearCaptureIfOwner(sess, engine) + sess.Stop() + pw.Close() + }() + defer pr.Close() + + log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr()) + defer func() { + stats := sess.Stats() + log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) + }() + + return streamToGRPC(pr, stream) +} + +func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error { + buf := make([]byte, 32*1024) + for { + n, readErr := r.Read(buf) + if n > 0 { + if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil { + log.Debugf("capture stream send: %v", err) + return nil //nolint:nilerr // client disconnected + } + } + if readErr != nil { + return nil //nolint:nilerr // pipe closed, capture stopped normally + } + } +} + +// StartBundleCapture begins capturing packets to a server-side temp file for +// inclusion in the next debug bundle. Not gated by --enable-capture since the +// output stays on the server (same trust level as CPU profiling). +// +// A timeout auto-stops the capture as a safety net if StopBundleCapture is +// never called (e.g. CLI crash). 
+func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + s.cleanupBundleCapture() + + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + + engine, err := s.getCaptureEngineLocked() + if err != nil { + // Not fatal: kernel mode or not connected. Log and return success + // so the debug bundle still generates without capture data. + log.Warnf("packet capture unavailable, skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + + timeout := req.GetTimeout().AsDuration() + if timeout <= 0 || timeout > maxBundleCaptureDuration { + timeout = maxBundleCaptureDuration + } + + f, err := os.CreateTemp("", "netbird.capture.*.pcap") + if err != nil { + return nil, status.Errorf(codes.Internal, "create temp file: %v", err) + } + + sess, err := capture.NewSession(capture.Options{Output: f}) + if err != nil { + f.Close() + os.Remove(f.Name()) + return nil, status.Errorf(codes.Internal, "create capture session: %v", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + f.Close() + os.Remove(f.Name()) + log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + s.activeCapture = sess + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + bc := &bundleCapture{ + sess: sess, + file: f, + engine: engine, + cancel: cancel, + } + + s.bundleCapture = bc + + go func() { + <-ctx.Done() + s.mutex.Lock() + if s.bundleCapture == bc { + s.stopBundleCaptureLocked() + } else { + bc.stop() + } + s.mutex.Unlock() + log.Infof("bundle capture auto-stopped after timeout") + }() + log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name()) + + return &proto.StartBundleCaptureResponse{}, nil +} + +// StopBundleCapture stops 
 the running bundle capture. Idempotent. +func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + return &proto.StopBundleCaptureResponse{}, nil +} + +// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex. +func (s *Server) stopBundleCaptureLocked() { + if s.bundleCapture == nil { + return + } + bc := s.bundleCapture + if bc.engine != nil && s.activeCapture == bc.sess { + if err := bc.engine.SetCapture(nil); err != nil { + log.Debugf("clear bundle capture: %v", err) + } + s.activeCapture = nil + } + bc.stop() + + stats := bc.sess.Stats() + log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) +} + +// bundleCapturePath stops any running bundle capture and returns its temp +// file path, or "" if no bundle capture was started. Called from DebugBundle. +// Must hold s.mutex. +func (s *Server) bundleCapturePath() string { + if s.bundleCapture == nil { + return "" + } + + s.bundleCapture.stop() + return s.bundleCapture.path() +} + +// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex. +func (s *Server) cleanupBundleCapture() { + if s.bundleCapture == nil { + return + } + s.bundleCapture.cleanup() + s.bundleCapture = nil +} + +// claimCapture reserves the engine's capture slot for sess. Returns +// FailedPrecondition if another capture is already active. +func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + engine, err := s.getCaptureEngineLocked() + if err != nil { + return nil, err + } + s.activeCapture = sess + return engine, nil +} + +// releaseCapture clears the active-capture owner if it still matches sess. 
+func (s *Server) releaseCapture(sess *capture.Session) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture == sess { + s.activeCapture = nil + } +} + +// clearCaptureIfOwner clears engine's capture slot only if sess still owns it. +func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture != sess { + return + } + if err := engine.SetCapture(nil); err != nil { + log.Debugf("clear capture: %v", err) + } + s.activeCapture = nil +} + +func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) { + if s.connectClient == nil { + return nil, status.Error(codes.FailedPrecondition, "client not connected") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, status.Error(codes.FailedPrecondition, "engine not initialized") + } + return engine, nil +} + +// parseCaptureFilter returns a Matcher from the request. +// Returns nil (match all) when no filter expression is set. 
+func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) { + expr := req.GetFilterExpr() + if expr == "" { + return nil, nil //nolint:nilnil // nil Matcher means "match all" + } + return capture.ParseFilter(expr) +} diff --git a/client/server/debug.go b/client/server/debug.go index 81708e576..33247db5f 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -43,7 +43,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( }() } - // Prepare refresh callback for health probes + capturePath := s.bundleCapturePath() + defer s.cleanupBundleCapture() + var refreshStatus func() if s.connectClient != nil { engine := s.connectClient.Engine() @@ -62,6 +64,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( SyncResponse: syncResponse, LogPath: s.logFile, CPUProfile: cpuProfileData, + CapturePath: capturePath, RefreshStatus: refreshStatus, ClientMetrics: clientMetrics, }, diff --git a/client/server/network.go b/client/server/network.go index 76c5af40e..12cefbd9c 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -18,10 +18,11 @@ import ( ) type selectRoute struct { - NetID route.NetID - Network netip.Prefix - Domains domain.List - Selected bool + NetID route.NetID + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } // ListNetworks returns a list of all available networks. @@ -50,18 +51,32 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro routesMap := routeMgr.GetClientRoutesWithNetID() routeSelector := routeMgr.GetRouteSelector() + v6ExitMerged := route.V6ExitMergeSet(routesMap) + var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + // Skip v6 exit nodes that are merged into their v4 counterpart. 
+ if _, ok := v6ExitMerged[id]; ok { + continue + } + + r := &selectRoute{ NetID: id, Network: rt[0].Network, Domains: rt[0].Domains, Selected: routeSelector.IsSelected(id), } - routes = append(routes, route) + + // Merge paired v6 exit node prefix into this entry. + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6ExitMerged[v6ID]; ok && len(routesMap[v6ID]) > 0 { + r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { @@ -82,9 +97,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() var pbRoutes []*proto.Network for _, route := range routes { + rangeStr := route.Network.String() + for _, extra := range route.extraNetworks { + rangeStr += ", " + extra.String() + } pbRoute := &proto.Network{ ID: string(route.NetID), - Range: route.Network.String(), + Range: rangeStr, Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, @@ -147,7 +166,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } @@ -197,7 +218,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + 
netIdRoutes := maps.Keys(routesMap) if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } diff --git a/client/server/server.go b/client/server/server.go index 70e4c342f..bc8de8f9f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" "github.com/netbirdio/netbird/version" ) @@ -89,7 +90,11 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool - networksDisabled bool + captureEnabled bool + bundleCapture *bundleCapture + // activeCapture is the session currently installed on the engine; guarded by s.mutex. + activeCapture *capture.Session + networksDisabled bool sleepHandler *sleephandler.SleepHandler @@ -106,7 +111,7 @@ type oauthAuthFlow struct { } // New server instance constructor. 
-func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server { +func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server { s := &Server{ rootCtx: ctx, logFile: logFile, @@ -115,11 +120,13 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + captureEnabled: captureEnabled, networksDisabled: networksDisabled, jwtCache: newJWTCache(), } agent := &serverAgent{s} s.sleepHandler = sleephandler.New(agent) + s.startSleepDetector() return s } @@ -378,6 +385,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.DisableNotifications = msg.DisableNotifications config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound + config.DisableIPv6 = msg.DisableIpv6 config.EnableSSHRoot = msg.EnableSSHRoot config.EnableSSHSFTP = msg.EnableSSHSFTP config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding @@ -1476,6 +1484,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p disableDNS := cfg.DisableDNS disableClientRoutes := cfg.DisableClientRoutes disableServerRoutes := cfg.DisableServerRoutes + disableIPv6 := cfg.DisableIPv6 blockLANAccess := cfg.BlockLANAccess enableSSHRoot := false @@ -1526,6 +1535,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p DisableDns: disableDNS, DisableClientRoutes: disableClientRoutes, DisableServerRoutes: disableServerRoutes, + DisableIpv6: disableIPv6, BlockLanAccess: blockLANAccess, EnableSSHRoot: enableSSHRoot, EnableSSHSFTP: enableSSHSFTP, diff --git a/client/server/server_test.go b/client/server/server_test.go index 
772997575..641cd85fe 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -104,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug", "", false, false, false) + s := New(ctx, "debug", "", false, false, false, false) s.config = config @@ -165,7 +165,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) @@ -235,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) @@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 7f6847c43..553d4ad71 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.NoError(t, err) ctx := context.Background() - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) rosenpassEnabled := true rosenpassPermissive := true @@ -71,6 +71,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { disableNotifications := 
true lazyConnectionEnabled := true blockInbound := true + disableIPv6 := true mtu := int64(1280) sshJWTCacheTTL := int32(300) @@ -95,6 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { DisableNotifications: &disableNotifications, LazyConnectionEnabled: &lazyConnectionEnabled, BlockInbound: &blockInbound, + DisableIpv6: &disableIPv6, NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, CleanNATExternalIPs: false, CustomDNSAddress: []byte("1.1.1.1:53"), @@ -140,6 +142,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.Equal(t, disableNotifications, *cfg.DisableNotifications) require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, disableIPv6, cfg.DisableIPv6) require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) // IFaceBlackList contains defaults + extras @@ -189,6 +192,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { "DisableNotifications": true, "LazyConnectionEnabled": true, "BlockInbound": true, + "DisableIpv6": true, "NatExternalIPs": true, "CustomDNSAddress": true, "ExtraIFaceBlacklist": true, @@ -247,6 +251,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) { "disable-firewall": "DisableFirewall", "block-lan-access": "BlockLanAccess", "block-inbound": "BlockInbound", + "disable-ipv6": "DisableIpv6", "enable-lazy-connection": "LazyConnectionEnabled", "external-ip-map": "NatExternalIPs", "dns-resolver-address": "CustomDNSAddress", diff --git a/client/server/sleep.go b/client/server/sleep.go index 7a83c75a6..877ad9690 100644 --- a/client/server/sleep.go +++ b/client/server/sleep.go @@ -2,13 +2,18 @@ package server import ( "context" + "os" + "strconv" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/sleep" "github.com/netbirdio/netbird/client/proto" ) +const envDisableSleepDetector = 
"NB_DISABLE_SLEEP_DETECTOR" + // serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces type serverAgent struct { s *Server @@ -28,19 +33,61 @@ func (a *serverAgent) Status() (internal.StatusType, error) { return internal.CtxGetState(a.s.rootCtx).Status() } -// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type. -func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) { - switch req.GetType() { - case proto.OSLifecycleRequest_WAKEUP: - if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil { - return &proto.OSLifecycleResponse{}, err - } - case proto.OSLifecycleRequest_SLEEP: - if err := s.sleepHandler.HandleSleep(callerCtx); err != nil { - return &proto.OSLifecycleResponse{}, err - } - default: - log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType()) +// startSleepDetector starts the OS sleep/wake detector and forwards events to +// the sleep handler. On platforms without a supported detector the attempt +// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips +// registration entirely. 
+func (s *Server) startSleepDetector() { + if sleepDetectorDisabled() { + log.Info("sleep detection disabled via " + envDisableSleepDetector) + return } - return &proto.OSLifecycleResponse{}, nil + + svc, err := sleep.New() + if err != nil { + log.Warnf("failed to initialize sleep detection: %v", err) + return + } + + err = svc.Register(func(event sleep.EventType) { + switch event { + case sleep.EventTypeSleep: + log.Info("handling sleep event") + if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil { + log.Errorf("failed to handle sleep event: %v", err) + } + case sleep.EventTypeWakeUp: + log.Info("handling wakeup event") + if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil { + log.Errorf("failed to handle wakeup event: %v", err) + } + } + }) + if err != nil { + log.Errorf("failed to register sleep detector: %v", err) + return + } + + log.Info("sleep detection service initialized") + + go func() { + <-s.rootCtx.Done() + log.Info("stopping sleep event listener") + if err := svc.Deregister(); err != nil { + log.Errorf("failed to deregister sleep detector: %v", err) + } + }() +} + +func sleepDetectorDisabled() bool { + val := os.Getenv(envDisableSleepDetector) + if val == "" { + return false + } + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err) + return false + } + return disabled } diff --git a/client/server/trace.go b/client/server/trace.go index e4ac91487..7fea31c49 100644 --- a/client/server/trace.go +++ b/client/server/trace.go @@ -24,14 +24,9 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( return nil, err } - srcAddr, err := s.parseAddress(req.GetSourceIp(), engine) + srcAddr, dstAddr, err := s.resolveTraceAddresses(req.GetSourceIp(), req.GetDestinationIp(), engine) if err != nil { - return nil, fmt.Errorf("invalid source IP address: %w", err) - } - - dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine) - if err != 
nil { - return nil, fmt.Errorf("invalid destination IP address: %w", err) + return nil, err } protocol, err := s.parseProtocol(req.GetProtocol()) @@ -89,16 +84,73 @@ func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) { return tracer, engine, nil } -func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) { - if addr == "self" { - return engine.GetWgAddr(), nil +// resolveTraceAddresses parses src/dst, resolving "self" to the local overlay +// address matching the peer's address family. +func (s *Server) resolveTraceAddresses(src, dst string, engine *internal.Engine) (netip.Addr, netip.Addr, error) { + srcSelf := src == "self" + dstSelf := dst == "self" + + if srcSelf && dstSelf { + return netip.Addr{}, netip.Addr{}, fmt.Errorf("both source and destination cannot be 'self'") } + var srcAddr, dstAddr netip.Addr + var err error + + // Parse the non-self address first so we know the family for self resolution. + if !srcSelf { + if srcAddr, err = parseAddr(src); err != nil { + return netip.Addr{}, netip.Addr{}, fmt.Errorf("invalid source IP: %w", err) + } + } + if !dstSelf { + if dstAddr, err = parseAddr(dst); err != nil { + return netip.Addr{}, netip.Addr{}, fmt.Errorf("invalid destination IP: %w", err) + } + } + + // Determine the peer address to pick the right self address. 
+ peer := srcAddr + if srcSelf { + peer = dstAddr + } + + if srcSelf { + if srcAddr, err = selfAddr(engine, peer); err != nil { + return netip.Addr{}, netip.Addr{}, err + } + } + if dstSelf { + if dstAddr, err = selfAddr(engine, peer); err != nil { + return netip.Addr{}, netip.Addr{}, err + } + } + + return srcAddr, dstAddr, nil +} + +func selfAddr(engine *internal.Engine, peer netip.Addr) (netip.Addr, error) { + var addr netip.Addr + if peer.Is6() { + addr = engine.GetWgV6Addr() + } else { + addr = engine.GetWgAddr() + } + if !addr.IsValid() { + family := "IPv4" + if peer.Is6() { + family = "IPv6" + } + return netip.Addr{}, fmt.Errorf("no local %s overlay address configured", family) + } + return addr, nil +} + +func parseAddr(addr string) (netip.Addr, error) { a, err := netip.ParseAddr(addr) if err != nil { return netip.Addr{}, err } - return a.Unmap(), nil } diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 6e584b2c3..01822ead6 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -3,6 +3,7 @@ package config import ( "context" "fmt" + "net/netip" "os" "path/filepath" "runtime" @@ -91,7 +92,8 @@ type Manager struct { // PeerSSHInfo represents a peer's SSH configuration information type PeerSSHInfo struct { Hostname string - IP string + IP netip.Addr + IPv6 netip.Addr FQDN string } @@ -210,8 +212,11 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { var hostPatterns []string - if peer.IP != "" { - hostPatterns = append(hostPatterns, peer.IP) + if peer.IP.IsValid() { + hostPatterns = append(hostPatterns, peer.IP.String()) + } + if peer.IPv6.IsValid() { + hostPatterns = append(hostPatterns, peer.IPv6.String()) } if peer.FQDN != "" { hostPatterns = append(hostPatterns, peer.FQDN) @@ -224,15 +229,20 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { func (m *Manager) writeSSHConfig(sshConfig string) 
error { sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + sshConfigPathTmp := sshConfigPath + ".tmp" if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) } - if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil { + if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil { return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) } + if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil { + return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err) + } + log.Infof("Created NetBird SSH client config: %s", sshConfigPath) return nil } diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go index e7380c7f2..8e6be40a3 100644 --- a/client/ssh/config/manager_test.go +++ b/client/ssh/config/manager_test.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net/netip" "os" "path/filepath" "runtime" @@ -28,12 +29,12 @@ func TestManager_SetupSSHClientConfig(t *testing.T) { peers := []PeerSSHInfo{ { Hostname: "peer1", - IP: "100.125.1.1", + IP: netip.MustParseAddr("100.125.1.1"), FQDN: "peer1.nb.internal", }, { Hostname: "peer2", - IP: "100.125.1.2", + IP: netip.MustParseAddr("100.125.1.2"), FQDN: "peer2.nb.internal", }, } @@ -101,7 +102,7 @@ func TestManager_PeerLimit(t *testing.T) { for i := 0; i < MaxPeersForSSHConfig+10; i++ { peers = append(peers, PeerSSHInfo{ Hostname: fmt.Sprintf("peer%d", i), - IP: fmt.Sprintf("100.125.1.%d", i%254+1), + IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)), FQDN: fmt.Sprintf("peer%d.nb.internal", i), }) } @@ -127,8 +128,8 @@ func TestManager_MatchHostFormat(t *testing.T) { } peers := []PeerSSHInfo{ - {Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"}, - {Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"}, + {Hostname: "peer1", IP: netip.MustParseAddr("100.125.1.1"), 
FQDN: "peer1.nb.internal"}, + {Hostname: "peer2", IP: netip.MustParseAddr("100.125.1.2"), FQDN: "peer2.nb.internal"}, } err = manager.SetupSSHClientConfig(peers) @@ -167,7 +168,7 @@ func TestManager_ForcedSSHConfig(t *testing.T) { for i := 0; i < MaxPeersForSSHConfig+10; i++ { peers = append(peers, PeerSSHInfo{ Hostname: fmt.Sprintf("peer%d", i), - IP: fmt.Sprintf("100.125.1.%d", i%254+1), + IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)), FQDN: fmt.Sprintf("peer%d.nb.internal", i), }) } diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index f6bc0d250..73b50122c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -322,7 +322,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne return } - dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort) + dest := net.JoinHostPort(payload.DestAddr, strconv.Itoa(int(payload.DestPort))) log.Debugf("local port forwarding: %s", dest) backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User()) diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index a819840a5..a47fdb48a 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { logger := s.getRequestLogger(ctx) if !allowLocal { - logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort) + logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort)))) return false } if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { - logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err) + logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err) return 
false } @@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { logger := s.getRequestLogger(ctx) if !allowRemote { - logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort) + logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort)))) return false } if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { - logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err) + logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err) return false } @@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req * return false, nil } - key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))) + key := forwardKey(hostPort) if s.removeRemoteForwardListener(key) { - forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port) + forwardAddr := "-R " + hostPort s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr) - logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) + logger.Infof("remote port forwarding cancelled: %s", hostPort) return true, nil } - logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port) + logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))) return false, nil } @@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h defer func() { if err := ln.Close(); err != nil { - logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err) + logger.Debugf("remote forward listener close error 
for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err) } }() @@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h } go s.handleRemoteForwardConnection(ctx, result.conn, host, port) case <-ctx.Done(): - logger.Debugf("remote forward listener shutting down for %s:%d", host, port) + logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port)))) return } } @@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) } - key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))) s.storeRemoteForwardListener(key, ln) - forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort) + forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))) s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) response := make([]byte, 4) binary.BigEndian.PutUint32(response, actualPort) - logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort) + logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))) return true, response } @@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr) if err != nil { - logger.Debugf("open forward channel for %s:%d: %v", host, port, err) + logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err) _ = conn.Close() return } diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 739eec513..6735e0f3b 100644 --- a/client/ssh/server/server.go +++ 
b/client/ssh/server/server.go @@ -143,10 +143,11 @@ type sessionState struct { } type Server struct { - sshServer *ssh.Server - listener net.Listener - mu sync.RWMutex - hostKeyPEM []byte + sshServer *ssh.Server + listener net.Listener + extraListeners []net.Listener + mu sync.RWMutex + hostKeyPEM []byte // sessions tracks active SSH sessions (shell, command, SFTP). // These are created when a client opens a session channel and requests shell/exec/subsystem. @@ -260,6 +261,35 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { return nil } +// AddListener starts serving SSH on an additional address (e.g. IPv6). +// Must be called after Start. +func (s *Server) AddListener(ctx context.Context, addr netip.AddrPort) error { + s.mu.Lock() + srv := s.sshServer + if srv == nil { + s.mu.Unlock() + return errors.New("SSH server is not running") + } + + ln, addrDesc, err := s.createListener(ctx, addr) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("create listener: %w", err) + } + + s.extraListeners = append(s.extraListeners, ln) + s.mu.Unlock() + + log.Infof("SSH server also listening on %s", addrDesc) + + go func() { + if err := srv.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Errorf("SSH server error on %s: %v", addrDesc, err) + } + }() + return nil +} + func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) { if s.netstackNet != nil { ln, err := s.netstackNet.ListenTCPAddrPort(addr) @@ -297,6 +327,8 @@ func (s *Server) Stop() error { } s.sshServer = nil s.listener = nil + extraListeners := s.extraListeners + s.extraListeners = nil s.mu.Unlock() // Close outside the lock: session handlers need s.mu for unregisterSession. 
@@ -304,6 +336,12 @@ func (s *Server) Stop() error { log.Debugf("close SSH server: %v", err) } + for _, ln := range extraListeners { + if err := ln.Close(); err != nil { + log.Debugf("close extra SSH listener: %v", err) + } + } + s.mu.Lock() maps.Clear(s.sessions) maps.Clear(s.pendingAuthJWT) @@ -755,11 +793,10 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey { func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { s.mu.RLock() - netbirdNetwork := s.wgAddress.Network - localIP := s.wgAddress.IP + wgAddr := s.wgAddress s.mu.RUnlock() - if !netbirdNetwork.IsValid() || !localIP.IsValid() { + if !wgAddr.Network.IsValid() || !wgAddr.IP.IsValid() { return conn } @@ -775,14 +812,17 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP) return nil } + remoteIP = remoteIP.Unmap() // Block connections from our own IP (prevent local apps from connecting to ourselves) - if remoteIP == localIP { + if remoteIP == wgAddr.IP || wgAddr.IPv6.IsValid() && remoteIP == wgAddr.IPv6 { log.Warnf("SSH connection rejected from own IP %s", remoteIP) return nil } - if !netbirdNetwork.Contains(remoteIP) { + inV4 := wgAddr.Network.Contains(remoteIP) + inV6 := wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(remoteIP) + if !inV4 && !inV6 { log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP) return nil } @@ -882,20 +922,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.mu.RUnlock() if !allowLocal { - logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port) + logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))) _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") return } if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { - 
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) + logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err) _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") return } - forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port) + hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))) + forwardAddr := "-L " + hostPort s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) - logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) + logger.Infof("local port forwarding: %s", hostPort) s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger) } diff --git a/client/status/status.go b/client/status/status.go index 8c932bbab..11ed06c2d 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -60,6 +60,7 @@ type ConvertOptions struct { type PeerStateDetailOutput struct { FQDN string `json:"fqdn" yaml:"fqdn"` IP string `json:"netbirdIp" yaml:"netbirdIp"` + IPv6 string `json:"netbirdIpv6,omitempty" yaml:"netbirdIpv6,omitempty"` PubKey string `json:"publicKey" yaml:"publicKey"` Status string `json:"status" yaml:"status"` LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` @@ -139,6 +140,7 @@ type OutputOverview struct { SignalState SignalStateOutput `json:"signal" yaml:"signal"` Relays RelayStateOutput `json:"relays" yaml:"relays"` IP string `json:"netbirdIp" yaml:"netbirdIp"` + IPv6 string `json:"netbirdIpv6,omitempty" yaml:"netbirdIpv6,omitempty"` PubKey string `json:"publicKey" yaml:"publicKey"` KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"` FQDN string `json:"fqdn" yaml:"fqdn"` @@ -182,6 +184,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO SignalState: signalOverview, Relays: relayOverview, IP: pbFullStatus.GetLocalPeerState().GetIP(), + IPv6: 
pbFullStatus.GetLocalPeerState().GetIpv6(), PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(), KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(), FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), @@ -317,6 +320,7 @@ func mapPeers( timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local() peerState := PeerStateDetailOutput{ IP: pbPeerState.GetIP(), + IPv6: pbPeerState.GetIpv6(), PubKey: pbPeerState.GetPubKey(), Status: pbPeerState.GetConnStatus(), LastStatusUpdate: timeLocal, @@ -417,6 +421,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS interfaceIP = "N/A" } + ipv6Line := "" + if o.IPv6 != "" { + ipv6Line = fmt.Sprintf("NetBird IPv6: %s\n", o.IPv6) + } + var relaysString string if showRelays { for _, relay := range o.Relays.Details { @@ -549,6 +558,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS "Nameservers: %s\n"+ "FQDN: %s\n"+ "NetBird IP: %s\n"+ + "%s"+ "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Lazy connection: %s\n"+ @@ -566,6 +576,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS dnsServersString, domain.Domain(o.FQDN).SafeString(), interfaceIP, + ipv6Line, interfaceTypeString, rosenpassEnabledStatus, lazyConnectionEnabledStatus, @@ -616,6 +627,7 @@ func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { } pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP + pbFullStatus.LocalPeerState.Ipv6 = fullStatus.LocalPeerState.IPv6 pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN @@ -628,6 +640,7 @@ func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { for _, peerState := range fullStatus.Peers { pbPeerState := &proto.PeerState{ IP: peerState.IP, + Ipv6: peerState.IPv6, PubKey: peerState.PubKey, 
ConnStatus: peerState.ConnStatus.String(), ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), @@ -733,9 +746,15 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo networks = strings.Join(peerState.Networks, ", ") } + ipv6Line := "" + if peerState.IPv6 != "" { + ipv6Line = fmt.Sprintf(" NetBird IPv6: %s\n", peerState.IPv6) + } + peerString := fmt.Sprintf( "\n %s:\n"+ " NetBird IP: %s\n"+ + "%s"+ " Public key: %s\n"+ " Status: %s\n"+ " -- detail --\n"+ @@ -751,6 +770,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Latency: %s\n", domain.Domain(peerState.FQDN).SafeString(), peerState.IP, + ipv6Line, peerState.PubKey, peerState.Status, peerState.ConnType, @@ -787,6 +807,9 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi if len(ipsFilter) > 0 { _, ok := ipsFilter[peerState.IP] + if !ok { + _, ok = ipsFilter[peerState.Ipv6] + } if !ok { ipEval = true } @@ -905,6 +928,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) { peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) } + peer.IPv6 = a.AnonymizeIPString(peer.IPv6) peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) for i, route := range peer.Networks { @@ -929,6 +953,7 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) { overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error) overview.IP = a.AnonymizeIPString(overview.IP) + overview.IPv6 = a.AnonymizeIPString(overview.IPv6) for i, detail := range overview.Relays.Details { detail.URI = a.AnonymizeURI(detail.URI) detail.Error = a.AnonymizeString(detail.Error) diff --git a/client/status/status_test.go b/client/status/status_test.go index 7754eebae..0986bf0cd 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -32,6 +32,7 @@ var resp = &proto.StatusResponse{ Peers: []*proto.PeerState{ { IP: "192.168.178.101", 
+ Ipv6: "fd00::1", PubKey: "Pubkey1", Fqdn: "peer-1.awesome-domain.com", ConnStatus: "Connected", @@ -90,6 +91,7 @@ var resp = &proto.StatusResponse{ }, LocalPeerState: &proto.LocalPeerState{ IP: "192.168.178.100/16", + Ipv6: "fd00::100", PubKey: "Some-Pub-Key", KernelInterface: true, Fqdn: "some-localhost.awesome-domain.com", @@ -130,6 +132,7 @@ var overview = OutputOverview{ Details: []PeerStateDetailOutput{ { IP: "192.168.178.101", + IPv6: "fd00::1", PubKey: "Pubkey1", FQDN: "peer-1.awesome-domain.com", Status: "Connected", @@ -204,6 +207,7 @@ var overview = OutputOverview{ }, }, IP: "192.168.178.100/16", + IPv6: "fd00::100", PubKey: "Some-Pub-Key", KernelInterface: true, FQDN: "some-localhost.awesome-domain.com", @@ -284,6 +288,7 @@ func TestParsingToJSON(t *testing.T) { { "fqdn": "peer-1.awesome-domain.com", "netbirdIp": "192.168.178.101", + "netbirdIpv6": "fd00::1", "publicKey": "Pubkey1", "status": "Connected", "lastStatusUpdate": "2001-01-01T01:01:01Z", @@ -361,6 +366,7 @@ func TestParsingToJSON(t *testing.T) { ] }, "netbirdIp": "192.168.178.100/16", + "netbirdIpv6": "fd00::100", "publicKey": "Some-Pub-Key", "usesKernelInterface": true, "fqdn": "some-localhost.awesome-domain.com", @@ -418,6 +424,7 @@ func TestParsingToYAML(t *testing.T) { details: - fqdn: peer-1.awesome-domain.com netbirdIp: 192.168.178.101 + netbirdIpv6: fd00::1 publicKey: Pubkey1 status: Connected lastStatusUpdate: 2001-01-01T01:01:01Z @@ -477,6 +484,7 @@ relays: available: false error: 'context: deadline exceeded' netbirdIp: 192.168.178.100/16 +netbirdIpv6: fd00::100 publicKey: Some-Pub-Key usesKernelInterface: true fqdn: some-localhost.awesome-domain.com @@ -523,6 +531,7 @@ func TestParsingToDetail(t *testing.T) { `Peers detail: peer-1.awesome-domain.com: NetBird IP: 192.168.178.101 + NetBird IPv6: fd00::1 Public key: Pubkey1 Status: Connected -- detail -- @@ -568,6 +577,7 @@ Nameservers: [1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout FQDN: 
some-localhost.awesome-domain.com NetBird IP: 192.168.178.100/16 +NetBird IPv6: fd00::100 Interface type: Kernel Quantum resistance: false Lazy connection: false @@ -592,6 +602,7 @@ Relays: 1/2 Available Nameservers: 1/2 Available FQDN: some-localhost.awesome-domain.com NetBird IP: 192.168.178.100/16 +NetBird IPv6: fd00::100 Interface type: Kernel Quantum resistance: false Lazy connection: false diff --git a/client/system/info.go b/client/system/info.go index 175d1f07f..477d5162b 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -69,6 +69,7 @@ type Info struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool LazyConnectionEnabled bool @@ -83,7 +84,7 @@ func (i *Info) SetFlags( rosenpassEnabled, rosenpassPermissive bool, serverSSHAllowed *bool, disableClientRoutes, disableServerRoutes, - disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool, + disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool, enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool, disableSSHAuth *bool, ) { @@ -99,6 +100,7 @@ func (i *Info) SetFlags( i.DisableFirewall = disableFirewall i.BlockLANAccess = blockLANAccess i.BlockInbound = blockInbound + i.DisableIPv6 = disableIPv6 i.LazyConnectionEnabled = lazyConnectionEnabled diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index c149b2152..c2129c7a2 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -38,10 +38,10 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/internal/sleep" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" + "github.com/netbirdio/netbird/client/ui/notifier" 
"github.com/netbirdio/netbird/client/ui/process" "github.com/netbirdio/netbird/util" @@ -260,6 +260,7 @@ type serviceClient struct { // application with main windows. app fyne.App + notifier notifier.Notifier wSettings fyne.Window showAdvancedSettings bool sendNotification bool @@ -278,6 +279,7 @@ type serviceClient struct { sDisableDNS *widget.Check sDisableClientRoutes *widget.Check sDisableServerRoutes *widget.Check + sDisableIPv6 *widget.Check sBlockLANAccess *widget.Check sEnableSSHRoot *widget.Check sEnableSSHSFTP *widget.Check @@ -298,6 +300,7 @@ type serviceClient struct { disableDNS bool disableClientRoutes bool disableServerRoutes bool + disableIPv6 bool blockLANAccess bool enableSSHRoot bool enableSSHSFTP bool @@ -364,6 +367,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { cancel: cancel, addr: args.addr, app: args.app, + notifier: notifier.New(args.app), logFile: args.logFile, sendNotification: false, @@ -466,6 +470,7 @@ func (s *serviceClient) showSettingsUI() { s.sDisableDNS = widget.NewCheck("Keeps system DNS settings unchanged", nil) s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil) s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil) + s.sDisableIPv6 = widget.NewCheck("Disable IPv6 overlay addressing", nil) s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil) s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil) s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil) @@ -583,6 +588,7 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool s.disableDNS != s.sDisableDNS.Checked || s.disableClientRoutes != s.sDisableClientRoutes.Checked || s.disableServerRoutes != s.sDisableServerRoutes.Checked || + s.disableIPv6 != s.sDisableIPv6.Checked || s.blockLANAccess != s.sBlockLANAccess.Checked || s.hasSSHChanges() } @@ -635,6 +641,7 @@ func (s *serviceClient) 
buildSetConfigRequest(iMngURL string, port, mtu int64) ( req.DisableDns = &s.sDisableDNS.Checked req.DisableClientRoutes = &s.sDisableClientRoutes.Checked req.DisableServerRoutes = &s.sDisableServerRoutes.Checked + req.DisableIpv6 = &s.sDisableIPv6.Checked req.BlockLanAccess = &s.sBlockLANAccess.Checked req.EnableSSHRoot = &s.sEnableSSHRoot.Checked @@ -674,24 +681,23 @@ func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error { return fmt.Errorf("set config: %w", err) } - // Reconnect if connected to apply the new settings + // Reconnect if connected to apply the new settings. + // Use a background context so the reconnect outlives the settings window. go func() { - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) + log.Errorf("failed to get service status: %v", err) return } if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) + if _, err = conn.Down(ctx, &proto.DownRequest{}); err != nil { + log.Errorf("failed to stop service: %v", err) } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - return + // TODO: wait for the service to be idle before calling Up, or use a fresh connection + if _, err = conn.Up(ctx, &proto.UpRequest{}); err != nil { + log.Errorf("failed to start service: %v", err) } } }() @@ -728,6 +734,7 @@ func (s *serviceClient) getNetworkForm() *widget.Form { {Text: "Disable DNS", Widget: s.sDisableDNS}, {Text: "Disable Client Routes", Widget: s.sDisableClientRoutes}, {Text: "Disable Server Routes", Widget: s.sDisableServerRoutes}, + {Text: "Disable IPv6", Widget: s.sDisableIPv6}, {Text: "Disable LAN Access", Widget: s.sBlockLANAccess}, }, 
} @@ -892,7 +899,7 @@ func (s *serviceClient) updateStatus() error { if err != nil { log.Errorf("get service status: %v", err) if s.connected { - s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost")) + s.notifier.Send("Error", "Connection to service lost") } s.setDisconnectedStatus() return err @@ -1109,7 +1116,7 @@ func (s *serviceClient) onTrayReady() { } }() - s.eventManager = event.NewManager(s.app, s.addr) + s.eventManager = event.NewManager(s.notifier, s.addr) s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked()) s.eventManager.AddHandler(func(event *proto.SystemEvent) { if event.Category == proto.SystemEvent_SYSTEM { @@ -1146,9 +1153,6 @@ func (s *serviceClient) onTrayReady() { go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) - - // Start sleep detection listener - go s.startSleepListener() } func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { @@ -1209,62 +1213,6 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService return s.conn, nil } -// startSleepListener initializes the sleep detection service and listens for sleep events -func (s *serviceClient) startSleepListener() { - sleepService, err := sleep.New() - if err != nil { - log.Warnf("%v", err) - return - } - - if err := sleepService.Register(s.handleSleepEvents); err != nil { - log.Errorf("failed to start sleep detection: %v", err) - return - } - - log.Info("sleep detection service initialized") - - // Cleanup on context cancellation - go func() { - <-s.ctx.Done() - log.Info("stopping sleep event listener") - if err := sleepService.Deregister(); err != nil { - log.Errorf("failed to deregister sleep detection: %v", err) - } - }() -} - -// handleSleepEvents sends a sleep notification to the daemon via gRPC -func (s *serviceClient) handleSleepEvents(event sleep.EventType) { - conn, err := s.getSrvClient(0) - if err != nil { - log.Errorf("failed to get daemon client for sleep notification: %v", err) - return - 
} - - req := &proto.OSLifecycleRequest{} - - switch event { - case sleep.EventTypeWakeUp: - log.Infof("handle wakeup event: %v", event) - req.Type = proto.OSLifecycleRequest_WAKEUP - case sleep.EventTypeSleep: - log.Infof("handle sleep event: %v", event) - req.Type = proto.OSLifecycleRequest_SLEEP - default: - log.Infof("unknown event: %v", event) - return - } - - _, err = conn.NotifyOSLifecycle(s.ctx, req) - if err != nil { - log.Errorf("failed to notify daemon about os lifecycle notification: %v", err) - return - } - - log.Info("successfully notified daemon about os lifecycle") -} - // setSettingsEnabled enables or disables the settings menu based on the provided state func (s *serviceClient) setSettingsEnabled(enabled bool) { if s.mSettings != nil { @@ -1384,6 +1332,7 @@ func (s *serviceClient) getSrvConfig() { s.disableDNS = cfg.DisableDNS s.disableClientRoutes = cfg.DisableClientRoutes s.disableServerRoutes = cfg.DisableServerRoutes + s.disableIPv6 = cfg.DisableIPv6 s.blockLANAccess = cfg.BlockLANAccess if cfg.EnableSSHRoot != nil { @@ -1424,6 +1373,7 @@ func (s *serviceClient) getSrvConfig() { s.sDisableDNS.SetChecked(cfg.DisableDNS) s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes) s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes) + s.sDisableIPv6.SetChecked(cfg.DisableIPv6) s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess) if cfg.EnableSSHRoot != nil { s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot) @@ -1511,6 +1461,7 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.DisableDNS = cfg.DisableDns config.DisableClientRoutes = cfg.DisableClientRoutes config.DisableServerRoutes = cfg.DisableServerRoutes + config.DisableIPv6 = cfg.DisableIpv6 config.BlockLANAccess = cfg.BlockLanAccess config.EnableSSHRoot = &cfg.EnableSSHRoot @@ -1548,7 +1499,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) { if enforced && s.lastNotifiedVersion != newVersion { s.lastNotifiedVersion = 
newVersion - s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install")) + s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install") } } diff --git a/client/ui/debug.go b/client/ui/debug.go index 4ebe4d675..cf5ac1a75 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -16,6 +16,7 @@ import ( "fyne.io/fyne/v2/widget" log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" @@ -38,6 +39,7 @@ type debugCollectionParams struct { upload bool uploadURL string enablePersistence bool + capture bool } // UI components for progress tracking @@ -51,25 +53,58 @@ type progressUI struct { func (s *serviceClient) showDebugUI() { w := s.app.NewWindow("NetBird Debug") w.SetOnClosed(s.cancel) - w.Resize(fyne.NewSize(600, 500)) w.SetFixedSize(true) anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil) systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil) systemInfoCheck.SetChecked(true) + captureCheck := widget.NewCheck("Include packet capture", nil) uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil) uploadCheck.SetChecked(true) - uploadURLLabel := widget.NewLabel("Debug upload URL:") + uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck) + + debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection() + + statusLabel := widget.NewLabel("") + statusLabel.Hide() + progressBar := widget.NewProgressBar() + progressBar.Hide() + createButton := widget.NewButton("Create Debug Bundle", nil) + + uiControls := []fyne.Disableable{ + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton, + } + + 
createButton.OnTapped = s.getCreateHandler( + statusLabel, progressBar, uploadCheck, uploadURL, + anonymizeCheck, systemInfoCheck, captureCheck, + runForDurationCheck, durationInput, uiControls, w, + ) + + content := container.NewVBox( + widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), + widget.NewLabel(""), + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURLContainer, + widget.NewLabel(""), + debugModeContainer, noteLabel, + widget.NewLabel(""), + statusLabel, progressBar, createButton, + ) + + w.SetContent(container.NewPadded(content)) + w.Show() +} + +func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) { uploadURL := widget.NewEntry() uploadURL.SetText(uptypes.DefaultBundleURL) uploadURL.SetPlaceHolder("Enter upload URL") - uploadURLContainer := container.NewVBox( - uploadURLLabel, - uploadURL, - ) + uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL) uploadCheck.OnChanged = func(checked bool) { if checked { @@ -78,13 +113,14 @@ func (s *serviceClient) showDebugUI() { uploadURLContainer.Hide() } } + return uploadURLContainer, uploadURL +} - debugModeContainer := container.NewHBox() +func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) { runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil) runForDurationCheck.SetChecked(true) forLabel := widget.NewLabel("for") - durationInput := widget.NewEntry() durationInput.SetText("1") minutesLabel := widget.NewLabel("minute") @@ -108,63 +144,8 @@ func (s *serviceClient) showDebugUI() { } } - debugModeContainer.Add(runForDurationCheck) - debugModeContainer.Add(forLabel) - debugModeContainer.Add(durationInput) - debugModeContainer.Add(minutesLabel) - - statusLabel := widget.NewLabel("") - statusLabel.Hide() - - progressBar := widget.NewProgressBar() - progressBar.Hide() - - createButton := 
widget.NewButton("Create Debug Bundle", nil) - - // UI controls that should be disabled during debug collection - uiControls := []fyne.Disableable{ - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURL, - runForDurationCheck, - durationInput, - createButton, - } - - createButton.OnTapped = s.getCreateHandler( - statusLabel, - progressBar, - uploadCheck, - uploadURL, - anonymizeCheck, - systemInfoCheck, - runForDurationCheck, - durationInput, - uiControls, - w, - ) - - content := container.NewVBox( - widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), - widget.NewLabel(""), - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURLContainer, - widget.NewLabel(""), - debugModeContainer, - noteLabel, - widget.NewLabel(""), - statusLabel, - progressBar, - createButton, - ) - - paddedContent := container.NewPadded(content) - w.SetContent(paddedContent) - - w.Show() + modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel) + return modeContainer, runForDurationCheck, durationInput, noteLabel } func validateMinute(s string, minutesLabel *widget.Label) error { @@ -200,6 +181,7 @@ func (s *serviceClient) getCreateHandler( uploadURL *widget.Entry, anonymizeCheck *widget.Check, systemInfoCheck *widget.Check, + captureCheck *widget.Check, runForDurationCheck *widget.Check, duration *widget.Entry, uiControls []fyne.Disableable, @@ -222,6 +204,7 @@ func (s *serviceClient) getCreateHandler( params := &debugCollectionParams{ anonymize: anonymizeCheck.Checked, systemInfo: systemInfoCheck.Checked, + capture: captureCheck.Checked, upload: uploadCheck.Checked, uploadURL: url, enablePersistence: true, @@ -253,10 +236,7 @@ func (s *serviceClient) getCreateHandler( statusLabel.SetText("Creating debug bundle...") go s.handleDebugCreation( - anonymizeCheck.Checked, - systemInfoCheck.Checked, - uploadCheck.Checked, - url, + params, statusLabel, uiControls, w, @@ -371,7 +351,7 @@ func 
startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time func (s *serviceClient) configureServiceForDebug( conn proto.DaemonServiceClient, state *debugInitialState, - enablePersistence bool, + params *debugCollectionParams, ) { if state.wasDown { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { @@ -397,7 +377,7 @@ func (s *serviceClient) configureServiceForDebug( time.Sleep(time.Second) } - if enablePersistence { + if params.enablePersistence { if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { @@ -417,6 +397,26 @@ func (s *serviceClient) configureServiceForDebug( if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { log.Warnf("failed to start CPU profiling: %v", err) } + + s.startBundleCaptureIfEnabled(conn, params) +} + +func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) { + if !params.capture { + return + } + + const maxCapture = 10 * time.Minute + timeout := params.duration + 30*time.Second + if timeout > maxCapture { + timeout = maxCapture + log.Warnf("packet capture clamped to %s (server maximum)", maxCapture) + } + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(timeout), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } } func (s *serviceClient) collectDebugData( @@ -430,7 +430,7 @@ func (s *serviceClient) collectDebugData( var wg sync.WaitGroup startProgressTracker(ctx, &wg, params.duration, progress) - s.configureServiceForDebug(conn, state, params.enablePersistence) + s.configureServiceForDebug(conn, state, params) wg.Wait() progress.progressBar.Hide() @@ -440,6 +440,14 @@ func (s *serviceClient) collectDebugData( log.Warnf("failed to stop CPU profiling: %v", err) } + if params.capture { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + 
defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + } + return nil } @@ -520,18 +528,37 @@ func handleError(progress *progressUI, errMsg string) { } func (s *serviceClient) handleDebugCreation( - anonymize bool, - systemInfo bool, - upload bool, - uploadURL string, + params *debugCollectionParams, statusLabel *widget.Label, uiControls []fyne.Disableable, w fyne.Window, ) { - log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...", - anonymize, systemInfo, upload) + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + log.Errorf("Failed to get client for debug: %v", err) + statusLabel.SetText(fmt.Sprintf("Error: %v", err)) + enableUIControls(uiControls) + return + } - resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL) + if params.capture { + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(30 * time.Second), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } else { + defer func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + }() + time.Sleep(2 * time.Second) + } + } + + resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL) if err != nil { log.Errorf("Failed to create debug bundle: %v", err) statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err)) @@ -543,7 +570,7 @@ func (s *serviceClient) handleDebugCreation( uploadFailureReason := resp.GetUploadFailureReason() uploadedKey := resp.GetUploadedKey() - if upload { + if params.upload { if uploadFailureReason != "" { showUploadFailedDialog(w, localPath, uploadFailureReason) } else { diff --git 
a/client/ui/event/event.go b/client/ui/event/event.go index b8ed09a5c..3b43fdc7f 100644 --- a/client/ui/event/event.go +++ b/client/ui/event/event.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "fyne.io/fyne/v2" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -18,11 +17,17 @@ import ( "github.com/netbirdio/netbird/client/ui/desktop" ) +// Notifier sends desktop notifications. Defined here so the event package +// does not depend on fyne or the platform-specific notifier implementation. +type Notifier interface { + Send(title, body string) +} + type Handler func(*proto.SystemEvent) type Manager struct { - app fyne.App - addr string + notifier Notifier + addr string mu sync.Mutex ctx context.Context @@ -31,10 +36,10 @@ type Manager struct { handlers []Handler } -func NewManager(app fyne.App, addr string) *Manager { +func NewManager(notifier Notifier, addr string) *Manager { return &Manager{ - app: app, - addr: addr, + notifier: notifier, + addr: addr, } } @@ -107,14 +112,14 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) { handlers := slices.Clone(e.handlers) e.mu.Unlock() - if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) { + if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) { title := e.getEventTitle(event) body := event.UserMessage id := event.Metadata["id"] if id != "" { body += fmt.Sprintf(" ID: %s", id) } - e.app.SendNotification(fyne.NewNotification(title, body)) + e.notifier.Send(title, body) } for _, handler := range handlers { @@ -128,6 +133,14 @@ func (e *Manager) AddHandler(handler Handler) { e.handlers = append(e.handlers, handler) } +// isV6DefaultRoutePartner reports whether the event is the IPv6 half of a +// paired v4/v6 default-route event. 
Management always pairs ::/0 with 0.0.0.0/0 +// for exit nodes, so the v4 partner already drives the user-facing toast and +// the v6 one is suppressed to avoid a duplicate notification. +func isV6DefaultRoutePartner(event *proto.SystemEvent) bool { + return event.Category == proto.SystemEvent_NETWORK && event.Metadata["network"] == "::/0" +} + func (e *Manager) getEventTitle(event *proto.SystemEvent) string { var prefix string switch event.Severity { diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 60a580dae..876fcef5f 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -9,7 +9,6 @@ import ( "os" "os/exec" - "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" @@ -87,7 +86,7 @@ func (h *eventHandler) handleConnectClick() { if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { log.Debugf("connect operation cancelled by user") } else { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + h.client.notifier.Send("Error", "Failed to connect") log.Errorf("connect failed: %v", err) } } @@ -112,7 +111,7 @@ func (h *eventHandler) handleDisconnectClick() { if err := h.client.menuDownClick(); err != nil { st, ok := status.FromError(err) if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + h.client.notifier.Send("Error", "Failed to disconnect") log.Errorf("disconnect failed: %v", err) } else { log.Debugf("disconnect cancelled or already disconnecting") @@ -130,7 +129,7 @@ func (h *eventHandler) handleAllowSSHClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings")) + h.client.notifier.Send("Error", 
"Failed to update SSH settings") } } @@ -140,7 +139,7 @@ func (h *eventHandler) handleAutoConnectClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings")) + h.client.notifier.Send("Error", "Failed to update auto-connect settings") } } @@ -149,7 +148,7 @@ func (h *eventHandler) handleRosenpassClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings")) + h.client.notifier.Send("Error", "Failed to update Rosenpass settings") } } @@ -158,7 +157,7 @@ func (h *eventHandler) handleLazyConnectionClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings")) + h.client.notifier.Send("Error", "Failed to update lazy connection settings") } } @@ -167,7 +166,7 @@ func (h *eventHandler) handleBlockInboundClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings")) + h.client.notifier.Send("Error", "Failed to update block inbound settings") } } @@ -176,7 +175,7 @@ func (h *eventHandler) handleNotificationsClick() { if err := h.updateConfigWithErr(); err != nil { h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error log.Errorf("failed to update config: %v", err) - 
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings")) + h.client.notifier.Send("Error", "Failed to update notifications settings") } else if h.client.eventManager != nil { h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked()) } diff --git a/client/ui/network.go b/client/ui/network.go index 571e871bb..1619f78a2 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -192,10 +192,14 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { return filteredRoutes } +func isDefaultRoute(routeRange string) bool { + return routeRange == "0.0.0.0/0" || routeRange == "::/0" +} + func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { var filteredRoutes []*proto.Network for _, route := range routes { - if route.Range == "0.0.0.0/0" { + if isDefaultRoute(route.Range) { filteredRoutes = append(filteredRoutes, route) } } @@ -499,7 +503,7 @@ func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.N var exitNodes []*proto.Network for _, network := range resp.Routes { - if network.Range == "0.0.0.0/0" { + if isDefaultRoute(network.Range) { exitNodes = append(exitNodes, network) } } diff --git a/client/ui/notifier/notifier.go b/client/ui/notifier/notifier.go new file mode 100644 index 000000000..8d1cbe4c4 --- /dev/null +++ b/client/ui/notifier/notifier.go @@ -0,0 +1,27 @@ +// Package notifier sends desktop notifications. On Windows it uses the WinRT +// COM API directly via go-toast/v2 to avoid the PowerShell window flash that +// fyne's default implementation produces. On other platforms it delegates to +// fyne. +package notifier + +import "fyne.io/fyne/v2" + +// Notifier sends desktop notifications. +type Notifier interface { + Send(title, body string) +} + +// New returns a platform-specific Notifier. 
The fyne app is used as the +// fallback notifier on platforms where no native implementation is wired up, +// and on Windows when the COM path fails to initialize. +func New(app fyne.App) Notifier { + return newNotifier(app) +} + +type fyneNotifier struct { + app fyne.App +} + +func (f *fyneNotifier) Send(title, body string) { + f.app.SendNotification(fyne.NewNotification(title, body)) +} diff --git a/client/ui/notifier/notifier_other.go b/client/ui/notifier/notifier_other.go new file mode 100644 index 000000000..686d2885f --- /dev/null +++ b/client/ui/notifier/notifier_other.go @@ -0,0 +1,9 @@ +//go:build !windows + +package notifier + +import "fyne.io/fyne/v2" + +func newNotifier(app fyne.App) Notifier { + return &fyneNotifier{app: app} +} diff --git a/client/ui/notifier/notifier_windows.go b/client/ui/notifier/notifier_windows.go new file mode 100644 index 000000000..c7afb43ae --- /dev/null +++ b/client/ui/notifier/notifier_windows.go @@ -0,0 +1,88 @@ +package notifier + +import ( + "os" + "path/filepath" + "sync" + + "fyne.io/fyne/v2" + toast "git.sr.ht/~jackmordaunt/go-toast/v2" + "git.sr.ht/~jackmordaunt/go-toast/v2/wintoast" + log "github.com/sirupsen/logrus" +) + +const ( + // appID is the AppUserModelID shown in the Windows Action Center. It + // must match the System.AppUserModel.ID property set on the Start Menu + // shortcut by the MSI (see client/netbird.wxs); otherwise Windows + // groups toasts under a separate, unbranded entry. + appID = "NetBird" + + // appGUID identifies the COM activation callback class. Generated once + // for NetBird; do not change without coordinating an installer bump, + // since old registry entries pointing at the previous GUID would orphan. 
+ appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}" +) + +type comNotifier struct { + fallback *fyneNotifier + ready bool + iconPath string +} + +var ( + initOnce sync.Once + initErr error +) + +func newNotifier(app fyne.App) Notifier { + n := &comNotifier{ + fallback: &fyneNotifier{app: app}, + iconPath: resolveIcon(), + } + initOnce.Do(func() { + initErr = wintoast.SetAppData(wintoast.AppData{ + AppID: appID, + GUID: appGUID, + IconPath: n.iconPath, + }) + }) + if initErr != nil { + log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr) + return n.fallback + } + n.ready = true + return n +} + +func (n *comNotifier) Send(title, body string) { + if !n.ready { + n.fallback.Send(title, body) + return + } + notification := toast.Notification{ + AppID: appID, + Title: title, + Body: body, + Icon: n.iconPath, + } + if err := notification.Push(); err != nil { + log.Warnf("toast: push failed, using fyne fallback: %v", err) + n.fallback.Send(title, body) + } +} + +// resolveIcon returns an absolute path to the toast icon, or an empty string +// when no icon can be located. Windows requires a PNG/JPG for the +// AppUserModelId IconUri registry value; .ico is silently ignored. 
+func resolveIcon() string { + exe, err := os.Executable() + if err != nil { + return "" + } + candidate := filepath.Join(filepath.Dir(exe), "netbird.png") + if _, err := os.Stat(candidate); err == nil { + return candidate + } + return "" +} diff --git a/client/ui/profile.go b/client/ui/profile.go index 74189c9a0..7ee89e631 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -548,7 +548,7 @@ func (p *profileMenu) refresh() { if err != nil { log.Errorf("failed to switch profile: %v", err) // show notification dialog - p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile")) + p.serviceClient.notifier.Send("Error", "Failed to switch profile") return } @@ -628,9 +628,9 @@ func (p *profileMenu) refresh() { } if err := p.eventHandler.logout(p.ctx); err != nil { log.Errorf("logout failed: %v", err) - p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister")) + p.serviceClient.notifier.Send("Error", "Failed to deregister") } else { - p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully")) + p.serviceClient.notifier.Send("Success", "Deregistered successfully") } } } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d8e50ab6d..066fe043b 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -5,6 +5,9 @@ package main import ( "context" "fmt" + "net" + "strconv" + "sync" "syscall/js" "time" @@ -14,6 +17,7 @@ import ( netbird "github.com/netbirdio/netbird/client/embed" sshdetection "github.com/netbirdio/netbird/client/ssh/detection" nbstatus "github.com/netbirdio/netbird/client/status" + wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" @@ -81,6 +85,10 @@ func parseClientOptions(jsOptions js.Value) (netbird.Options, error) { options.DeviceName = deviceName.String() } + if 
disableIPv6 := jsOptions.Get("disableIPv6"); !disableIPv6.IsNull() && !disableIPv6.IsUndefined() { + options.DisableIPv6 = disableIPv6.Bool() + } + return options, nil } @@ -161,39 +169,58 @@ func createSSHMethod(client *netbird.Client) js.Func { }) } - var jwtToken string - if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { - jwtToken = args[3].String() - } + jwtToken, ipVersion := parseSSHOptions(args) return createPromise(func(resolve, reject js.Value) { - sshClient := ssh.NewClient(client) - - if err := sshClient.Connect(host, port, username, jwtToken); err != nil { + jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion) + if err != nil { reject.Invoke(err.Error()) return } - - if err := sshClient.StartSession(80, 24); err != nil { - if closeErr := sshClient.Close(); closeErr != nil { - log.Errorf("Error closing SSH client: %v", closeErr) - } - reject.Invoke(err.Error()) - return - } - - jsInterface := ssh.CreateJSInterface(sshClient) resolve.Invoke(jsInterface) }) }) } -func performPing(client *netbird.Client, hostname string) { +func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) { + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + if len(args) > 4 { + ipVersion = jsIPVersion(args[4]) + } + return +} + +func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil { + return js.Undefined(), err + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + return js.Undefined(), err + } + + return ssh.CreateJSInterface(sshClient), nil +} + +func performPing(client *netbird.Client, hostname string, ipVersion int) { ctx, cancel := 
context.WithTimeout(context.Background(), pingTimeout) defer cancel() + // Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack. + network := "ping4" + if ipVersion == 6 { + network = "ping6" + } + start := time.Now() - conn, err := client.Dial(ctx, "ping", hostname) + conn, err := client.Dial(ctx, network, hostname) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err)) return @@ -220,27 +247,39 @@ func performPing(client *netbird.Client, hostname string) { } latency := time.Since(start) - js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())) + remote := conn.RemoteAddr().String() + msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()) + if remote != hostname { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } -func performPingTCP(client *netbird.Client, hostname string, port int) { +func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() - address := fmt.Sprintf("%s:%d", hostname, port) + network := ipVersionNetwork("tcp", ipVersion) + + address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err)) return } latency := time.Since(start) + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { log.Debugf("failed to close TCP connection: %v", err) } - js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())) + msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()) + if remote != address { + msg += fmt.Sprintf(" (via %s)", 
remote) + } + js.Global().Get("console").Call("log", msg) } // createPingMethod creates the ping method @@ -257,8 +296,12 @@ func createPingMethod(client *netbird.Client) js.Func { } hostname := args[0].String() + var ipVersion int + if len(args) > 1 { + ipVersion = jsIPVersion(args[1]) + } return createPromise(func(resolve, reject js.Value) { - performPing(client, hostname) + performPing(client, hostname, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -285,8 +328,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func { hostname := args[0].String() port := args[1].Int() + var ipVersion int + if len(args) > 2 { + ipVersion = jsIPVersion(args[2]) + } return createPromise(func(resolve, reject js.Value) { - performPingTCP(client, hostname, port) + performPingTCP(client, hostname, port, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -459,6 +506,120 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4"). +func ipVersionNetwork(base string, ipVersion int) string { + switch ipVersion { + case 4: + return base + "4" + case 6: + return base + "6" + default: + return base + } +} + +// jsIPVersion extracts an IP version (4 or 6) from a JS string or number. +func jsIPVersion(v js.Value) int { + switch v.Type() { + case js.TypeNumber: + return v.Int() + case js.TypeString: + n, _ := strconv.Atoi(v.String()) + return n + default: + return 0 + } +} + +// createStartCaptureMethod creates the programmable packet capture method. +// Returns a JS interface with onpacket callback and stop() method. 
+// +// Usage from JavaScript: +// +// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true }) +// cap.onpacket = (line) => console.log(line) +// const stats = cap.stop() +func createStartCaptureMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + iface, err := wasmcapture.Start(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + resolve.Invoke(iface) + }) + }) +} + +// captureMethods returns capture() and stopCapture() that share state for +// the console-log shortcut. capture() logs packets to the browser console +// and stopCapture() ends it, like Ctrl+C on the CLI. +// +// Usage from browser devtools console: +// +// await client.capture() // capture all packets +// await client.capture("tcp") // capture with filter +// await client.capture({filter: "host 10.0.0.1", verbose: true}) +// client.stopCapture() // stop and print stats +func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) { + var mu sync.Mutex + var active *wasmcapture.Handle + + startFn = js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + mu.Lock() + defer mu.Unlock() + + if active != nil { + active.Stop() + active = nil + } + + h, err := wasmcapture.StartConsole(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + active = h + + console := js.Global().Get("console") + console.Call("log", "[capture] started, call client.stopCapture() to stop") + resolve.Invoke(js.Undefined()) + }) + }) + + stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + mu.Lock() + defer mu.Unlock() + + if active == nil { + js.Global().Get("console").Call("log", "[capture] no 
active capture") + return js.Undefined() + } + + stats := active.Stop() + active = nil + + console := js.Global().Get("console") + console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped)) + return js.Undefined() + }) + + return startFn, stopFn +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { @@ -521,6 +682,11 @@ func createClientObject(client *netbird.Client) js.Value { obj["statusDetail"] = createStatusDetailMethod(client) obj["getSyncResponse"] = createGetSyncResponseMethod(client) obj["setLogLevel"] = createSetLogLevelMethod(client) + obj["startCapture"] = createStartCaptureMethod(client) + + capStart, capStop := captureMethods(client) + obj["capture"] = capStart + obj["stopCapture"] = capStop return js.ValueOf(obj) } diff --git a/client/wasm/internal/capture/capture.go b/client/wasm/internal/capture/capture.go new file mode 100644 index 000000000..53e43c45e --- /dev/null +++ b/client/wasm/internal/capture/capture.go @@ -0,0 +1,176 @@ +//go:build js + +// Package capture bridges the util/capture package to JavaScript via syscall/js. +package capture + +import ( + "strings" + "sync" + "syscall/js" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +// Handle holds a running capture session so it can be stopped later. +type Handle struct { + cs *netbird.CaptureSession + stopFn js.Func + stopped bool +} + +// Stop ends the capture and returns stats. 
+func (h *Handle) Stop() netbird.CaptureStats { + if h.stopped { + return h.cs.Stats() + } + h.stopped = true + h.stopFn.Release() + + h.cs.Stop() + return h.cs.Stats() +} + +func statsToJS(s netbird.CaptureStats) js.Value { + obj := js.Global().Get("Object").Call("create", js.Null()) + obj.Set("packets", js.ValueOf(s.Packets)) + obj.Set("bytes", js.ValueOf(s.Bytes)) + obj.Set("dropped", js.ValueOf(s.Dropped)) + return obj +} + +// parseOpts extracts filter/verbose/ascii from a JS options value. +func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) { + if jsOpts.IsNull() || jsOpts.IsUndefined() { + return + } + if jsOpts.Type() == js.TypeString { + filter = jsOpts.String() + return + } + if jsOpts.Type() != js.TypeObject { + return + } + if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() { + filter = f.String() + } + if v := jsOpts.Get("verbose"); !v.IsUndefined() { + verbose = v.Truthy() + } + if a := jsOpts.Get("ascii"); !a.IsUndefined() { + ascii = a.Truthy() + } + return +} + +// Start creates a capture session and returns a JS interface for streaming text +// output. The returned object exposes: +// +// onpacket(callback) - set callback(string) for each text line +// stop() - stop capture and return stats { packets, bytes, dropped } +// +// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string. 
+func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + cb := &jsCallbackWriter{} + + cs, err := client.StartCapture(netbird.CaptureOptions{ + TextOutput: cb, + Filter: filter, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return js.Undefined(), err + } + + handle := &Handle{cs: cs} + + iface := js.Global().Get("Object").Call("create", js.Null()) + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + iface.Set("stop", handle.stopFn) + iface.Set("onpacket", js.Undefined()) + cb.setInterface(iface) + + return iface, nil +} + +// StartConsole starts a capture that logs every packet line to console.log. +// Returns a Handle so the caller can stop it later. +func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + cb := &jsCallbackWriter{} + + cs, err := client.StartCapture(netbird.CaptureOptions{ + TextOutput: cb, + Filter: filter, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return nil, err + } + + handle := &Handle{cs: cs} + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + + iface := js.Global().Get("Object").Call("create", js.Null()) + console := js.Global().Get("console") + iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]"))) + cb.setInterface(iface) + + return handle, nil +} + +// jsCallbackWriter is an io.Writer that buffers text until a newline, then +// invokes the JS onpacket callback with each complete line. 
+type jsCallbackWriter struct { + mu sync.Mutex + iface js.Value + buf strings.Builder +} + +func (w *jsCallbackWriter) setInterface(iface js.Value) { + w.mu.Lock() + defer w.mu.Unlock() + w.iface = iface +} + +func (w *jsCallbackWriter) Write(p []byte) (int, error) { + w.mu.Lock() + w.buf.Write(p) + + var lines []string + for { + str := w.buf.String() + idx := strings.IndexByte(str, '\n') + if idx < 0 { + break + } + lines = append(lines, str[:idx]) + w.buf.Reset() + if idx+1 < len(str) { + w.buf.WriteString(str[idx+1:]) + } + } + + iface := w.iface + w.mu.Unlock() + + if iface.IsUndefined() { + return len(p), nil + } + cb := iface.Get("onpacket") + if cb.IsUndefined() || cb.IsNull() { + return len(p), nil + } + for _, line := range lines { + cb.Invoke(js.ValueOf(line)) + } + return len(p), nil +} diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 16bf63bb9..6c36fdec6 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -82,7 +82,7 @@ func NewRDCleanPathProxy(client interface { // CreateProxy creates a new proxy endpoint for the given destination func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { - destination := fmt.Sprintf("%s:%s", hostname, port) + destination := net.JoinHostPort(hostname, port) return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any { resolve := args[0] diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index 568437e56..9cfe65266 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "io" + "net" "sync" "time" @@ -45,9 +46,10 @@ func NewClient(nbClient *netbird.Client) *Client { } } -// Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username, jwtToken string) error { - addr := fmt.Sprintf("%s:%d", host, port) +// Connect 
establishes an SSH connection through NetBird network. +// ipVersion may be 4, 6, or 0 for automatic selection. +func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error { + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) logrus.Infof("SSH: Connecting to %s as %s", addr, username) authMethods, err := c.getAuthMethods(jwtToken) @@ -62,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error Timeout: sshDialTimeout, } + network := "tcp" + switch ipVersion { + case 4: + network = "tcp4" + case 6: + network = "tcp6" + } + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) defer cancel() - conn, err := c.nbClient.Dial(ctx, "tcp", addr) + conn, err := c.nbClient.Dial(ctx, network, addr) if err != nil { return fmt.Errorf("dial %s: %w", addr, err) } diff --git a/combined/cmd/config.go b/combined/cmd/config.go index ce4df8394..9959f7a56 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -380,7 +380,7 @@ func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, // Auto-configure local STUN servers for all ports for _, port := range c.Server.StunPorts { c.Management.Stuns = append(c.Management.Stuns, HostConfig{ - URI: fmt.Sprintf("stun:%s:%d", exposedHost, port), + URI: "stun:" + net.JoinHostPort(strings.Trim(exposedHost, "[]"), fmt.Sprintf("%d", port)), }) } } diff --git a/combined/config.yaml.example b/combined/config.yaml.example index dce658d89..af85b0477 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -119,6 +119,8 @@ server: # Reverse proxy settings (optional) # reverseProxy: - # trustedHTTPProxies: [] - # trustedHTTPProxiesCount: 0 - # trustedPeers: [] + # trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. 
["10.0.0.0/8"]) + # trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies) + # trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"]) + # accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely). + # accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value. diff --git a/flow/client/client.go b/flow/client/client.go index 8ad637974..180a4b441 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -13,11 +13,9 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/status" nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" @@ -301,12 +299,11 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff }, ctx) } +// isContextDone reports whether the local context has been canceled or has +// exceeded its deadline. It deliberately does not inspect gRPC status codes: +// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not +// short-circuit our retry loop, since retrying is the correct response when +// the local context is still alive. 
func isContextDone(err error) bool { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true - } - if s, ok := status.FromError(err); ok { - return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded - } - return false + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) } diff --git a/flow/client/client_test.go b/flow/client/client_test.go index 55157acbc..c8f5f4af4 100644 --- a/flow/client/client_test.go +++ b/flow/client/client_test.go @@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) require.NoError(t, err) + + // Cleanups run LIFO: the goroutine-drain registered here runs after Close below, + // which is when Receive has actually returned. Without this, the Receive goroutine + // can outlive the test and call t.Logf after teardown, panicking. + receiveDone := make(chan struct{}) + t.Cleanup(func() { + select { + case <-receiveDone: + case <-time.After(2 * time.Second): + t.Error("Receive goroutine did not exit after Close") + } + }) t.Cleanup(func() { err := client.Close() assert.NoError(t, err, "failed to close flow") @@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { receivedAfterReconnect := make(chan struct{}) go func() { + defer close(receiveDone) err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { if msg.IsInitiator || len(msg.EventId) == 0 { return nil diff --git a/go.mod b/go.mod index 5172b1a78..bc4e8af15 100644 --- a/go.mod +++ b/go.mod @@ -17,8 +17,8 @@ require ( github.com/spf13/cobra v1.10.1 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.49.0 - golang.org/x/sys v0.42.0 + golang.org/x/crypto v0.50.0 + golang.org/x/sys v0.43.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl 
v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -30,6 +30,7 @@ require ( require ( fyne.io/fyne/v2 v2.7.0 fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 + git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 github.com/awnumar/memguard v0.23.0 github.com/aws/aws-sdk-go-v2 v1.38.3 github.com/aws/aws-sdk-go-v2/config v1.31.6 @@ -46,6 +47,7 @@ require ( github.com/crowdsecurity/go-cs-bouncer v0.0.21 github.com/dexidp/dex v0.0.0-00010101000000-000000000000 github.com/dexidp/dex/api/v2 v2.4.0 + github.com/ebitengine/purego v0.8.4 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2 @@ -66,9 +68,10 @@ require ( github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-nat v0.2.0 - github.com/libp2p/go-netroute v0.2.1 + github.com/libp2p/go-netroute v0.4.0 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 + github.com/mdp/qrterminal/v3 v3.2.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 @@ -115,11 +118,11 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/mobile v0.0.0-20251113184115-a159579294ab - golang.org/x/mod v0.33.0 - golang.org/x/net v0.52.0 + golang.org/x/mod v0.34.0 + golang.org/x/net v0.53.0 golang.org/x/oauth2 v0.36.0 golang.org/x/sync v0.20.0 - golang.org/x/term v0.41.0 + golang.org/x/term v0.42.0 golang.org/x/time v0.15.0 google.golang.org/api v0.276.0 gopkg.in/yaml.v3 v3.0.1 @@ -178,7 +181,6 @@ require ( github.com/docker/docker v28.0.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/ebitengine/purego v0.8.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.1 // indirect 
github.com/fyne-io/gl-js v0.2.0 // indirect @@ -301,13 +303,14 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/image v0.33.0 // indirect - golang.org/x/text v0.35.0 // indirect - golang.org/x/tools v0.42.0 // indirect + golang.org/x/text v0.36.0 // indirect + golang.org/x/tools v0.43.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + rsc.io/qr v0.2.0 // indirect ) replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 @@ -320,6 +323,6 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 -replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 - replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 + +replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index 9293ce73b..d54dc01e6 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ= fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE= fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4= fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= +git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 h1:N3IGoHHp9pb6mj1cbXbuaSXV/UMKwmbKLf53nQmtqMA= +git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3/go.mod h1:QtOLZGz8olr4qH2vWK0QH0w0O4T9fEIjMuWpKUsH7nc= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 
h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk= @@ -393,6 +395,8 @@ github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= +github.com/libp2p/go-netroute v0.4.0 h1:sZZx9hyANYUx9PZyqcgE/E1GUG3iEtTZHUEvdtXT7/Q= +github.com/libp2p/go-netroute v0.4.0/go.mod h1:Nkd5ShYgSMS5MUKy/MU2T57xFoOKvvLR92Lic48LEyA= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= @@ -400,8 +404,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattn/go-isatty v0.0.9/go.mod 
h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= @@ -415,6 +417,8 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0 github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4= +github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= @@ -449,8 +453,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= -github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= -github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= +github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= +github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 
h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= @@ -707,8 +711,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= -golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= @@ -725,8 +729,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= -golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -745,8 +749,8 @@ 
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= @@ -797,8 +801,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -811,8 +815,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod 
h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= -golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -824,8 +828,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= -golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -839,8 +843,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools 
v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= -golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -913,3 +917,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= +rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY= +rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs= diff --git a/idp/dex/config.go b/idp/dex/config.go index 7f5300f14..e686233ad 100644 --- a/idp/dex/config.go +++ b/idp/dex/config.go @@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) { // are stored with types that Dex can open. 
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) { switch connType { - case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": return "oidc", applyOIDCDefaults(connType, config) default: return connType, config @@ -218,6 +218,8 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"}) case "okta", "pocketid": augmented["scopes"] = []string{"openid", "profile", "email", "groups"} + case "adfs": + augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"} } return augmented diff --git a/idp/dex/connector.go b/idp/dex/connector.go index ba2bb1f00..fb20fdcc3 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro } // UpdateConnector updates an existing connector in Dex storage. -// It merges incoming updates with existing values to prevent data loss on partial updates. +// It overlays user-mutable config fields (issuer, clientID, clientSecret, +// redirectURI) onto the stored connector config, and updates the connector name +// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so +// partial updates preserve create-time defaults such as scopes, claimMapping, +// and userIDKey. 
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { - oldCfg, err := p.parseStorageConnector(old) - if err != nil { - return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err) + if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) { + return storage.Connector{}, errors.New("connector type change not allowed") } - mergeConnectorConfig(cfg, oldCfg) - - storageConn, err := p.buildStorageConnector(cfg) + configData, err := overlayConnectorConfig(old.Config, cfg) if err != nil { - return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err) + return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err) } - return storageConn, nil + + name := cfg.Name + if name == "" { + name = old.Name + } + + return storage.Connector{ + ID: cfg.ID, + Type: old.Type, + Name: name, + Config: configData, + }, nil }); err != nil { return fmt.Errorf("failed to update connector: %w", err) } @@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er return nil } -// mergeConnectorConfig preserves existing values for empty fields in the update. -func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) { - if cfg.ClientSecret == "" { - cfg.ClientSecret = oldCfg.ClientSecret +// overlayConnectorConfig writes only the user-mutable fields onto the existing +// stored config, preserving every other field (scopes, claimMapping, userIDKey, +// insecure flags, etc.). Empty fields on cfg leave the existing value alone. 
+func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) { + var m map[string]any + if err := decodeConnectorConfig(oldConfig, &m); err != nil { + return nil, err } - if cfg.RedirectURI == "" { - cfg.RedirectURI = oldCfg.RedirectURI + if cfg.Issuer != "" { + m["issuer"] = cfg.Issuer } - if cfg.Issuer == "" && cfg.Type == oldCfg.Type { - cfg.Issuer = oldCfg.Issuer + if cfg.ClientID != "" { + m["clientID"] = cfg.ClientID } - if cfg.ClientID == "" { - cfg.ClientID = oldCfg.ClientID + if cfg.ClientSecret != "" { + m["clientSecret"] = cfg.ClientSecret } - if cfg.Name == "" { - cfg.Name = oldCfg.Name + if cfg.RedirectURI != "" { + m["redirectURI"] = cfg.RedirectURI } + return encodeConnectorConfig(m) } // DeleteConnector removes a connector from Dex storage. @@ -168,7 +184,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto var err error switch cfg.Type { - case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": dexType = "oidc" configData, err = buildOIDCConnectorConfig(cfg, redirectURI) case "google": @@ -216,10 +232,16 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, oidcConfig["getUserInfo"] = true case "entra": oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} + // Use the Entra Object ID (oid) instead of the default OIDC sub claim. + // Entra issues sub as a per-app pairwise identifier that does not match + // the stable Object ID. 
+ oidcConfig["userIDKey"] = "oid" case "okta": oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} case "pocketid": oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} + case "adfs": + oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"} } return encodeConnectorConfig(oidcConfig) } @@ -283,7 +305,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa // inferOIDCProviderType infers the specific OIDC provider from connector ID func inferOIDCProviderType(connectorID string) string { connectorIDLower := strings.ToLower(connectorID) - for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} { + for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} { if strings.Contains(connectorIDLower, provider) { return provider } diff --git a/idp/dex/connector_test.go b/idp/dex/connector_test.go new file mode 100644 index 000000000..4253e02b7 --- /dev/null +++ b/idp/dex/connector_test.go @@ -0,0 +1,205 @@ +package dex + +import ( + "context" + "encoding/json" + "log/slog" + "os" + "path/filepath" + "testing" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestProvider(t *testing.T) (*Provider, func()) { + t.Helper() + tmpDir, err := os.MkdirTemp("", "dex-connector-test-*") + require.NoError(t, err) + + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger) + require.NoError(t, err) + + return &Provider{storage: s, logger: logger}, func() { + _ = s.Close() + _ = os.RemoveAll(tmpDir) + } +} + +func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) { + cfg := &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + 
ClientID: "client-id", + ClientSecret: "client-secret", + } + data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback") + require.NoError(t, err) + + var m map[string]any + require.NoError(t, json.Unmarshal(data, &m)) + + assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid") + assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"]) +} + +func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) { + // ensures the Entra userIDKey override does not leak into other OIDC providers, + // which already use a stable sub claim. + for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} { + t.Run(typ, func(t *testing.T) { + data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(data, &m)) + _, ok := m["userIDKey"] + assert.False(t, ok, "%s connectors must not have userIDKey set", typ) + }) + } +} + +func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + created, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "old-secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + require.Equal(t, "entra-test", created.ID) + + // Rotate only the client secret. 
+ err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "entra", + ClientSecret: "new-secret", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + + assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated") + assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)") + assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"]) + assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update") + assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update") +} + +func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + // Seed a connector directly into storage without userIDKey + preFixConfig, err := json.Marshal(map[string]any{ + "issuer": "https://login.microsoftonline.com/tid/v2.0", + "clientID": "client-id", + "clientSecret": "old-secret", + "redirectURI": "https://example.com/oauth2/callback", + "scopes": []string{"openid", "profile", "email"}, + "claimMapping": map[string]string{"email": "preferred_username"}, + }) + require.NoError(t, err) + + require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{ + ID: "entra-prefix", + Type: "oidc", + Name: "Entra", + Config: preFixConfig, + })) + + // Rotate client secret via UpdateConnector. 
+ err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-prefix", + Type: "entra", + ClientSecret: "new-secret", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-prefix") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + + assert.Equal(t, "new-secret", m["clientSecret"]) + _, has := m["userIDKey"] + assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before") +} + +func TestUpdateConnector_RejectsTypeChange(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + _, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + + // Attempt to switch the connector to okta. + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "okta", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "connector type change not allowed") + + // stored connector type/config unchanged after the rejected update. 
+ conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + assert.Equal(t, "oidc", conn.Type) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + assert.Equal(t, "oid", m["userIDKey"]) +} + +func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + _, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/old/v2.0", + ClientID: "client-id", + ClientSecret: "secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "entra", + Issuer: "https://login.microsoftonline.com/new/v2.0", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"]) +} diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 08da48264..9d1b57258 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -231,7 +231,20 @@ get_upstream_host() { wait_management_proxy() { local proxy_container="${1:-traefik}" + local use_docker_logs=false set +e + + if [[ "$proxy_container" == "detect-traefik" ]]; then + proxy_container=$(docker ps --format "{{.ID}}\t{{.Image}}\t{{.Ports}}" \ + | awk -F'\t' '$2 ~ /traefik/ && $3 ~ /:(80|443)->/ {print $1; exit}') + + if [[ -z "$proxy_container" ]]; then + echo "Warning: could not auto-detect Traefik container, log output will be skipped on timeout." 
> /dev/stderr + else + use_docker_logs=true + fi + fi + echo -n "Waiting for NetBird server to become ready" counter=1 while true; do @@ -242,7 +255,13 @@ wait_management_proxy() { if [[ $counter -eq 60 ]]; then echo "" echo "Taking too long. Checking logs..." - $DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container" + if [[ -n "$proxy_container" ]]; then + if [[ "$use_docker_logs" == "true" ]]; then + docker logs --tail=20 "$proxy_container" + else + $DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container" + fi + fi $DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server fi echo -n " ." @@ -472,7 +491,7 @@ start_services_and_show_instructions() { if [[ "$ENABLE_CROWDSEC" == "true" ]]; then echo "Registering CrowdSec bouncer..." local cs_retries=0 - while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do + while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do cs_retries=$((cs_retries + 1)) if [[ $cs_retries -ge 30 ]]; then echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." 
> /dev/stderr @@ -518,7 +537,7 @@ start_services_and_show_instructions() { $DOCKER_COMPOSE_COMMAND up -d sleep 3 - wait_management_direct + wait_management_proxy detect-traefik echo -e "$MSG_DONE" print_post_setup_instructions diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 4b414df6f..36de950e9 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -7,7 +7,6 @@ import ( "os" "slices" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -16,11 +15,9 @@ import ( "golang.org/x/exp/maps" "golang.org/x/mod/semver" - nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" - "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" @@ -58,13 +55,6 @@ type Controller struct { proxyController port_forwarding.Controller integratedPeerValidator integrated_validator.IntegratedValidator - - holder *types.Holder - - expNewNetworkMap bool - expNewNetworkMapAIDs map[string]struct{} - - compactedNetworkMap bool } type bufferUpdate struct { @@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App log.Fatal(fmt.Errorf("error creating metrics: %w", err)) } - newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err) - 
newNetworkMapBuilder = false - } - - compactedNetworkMap := true - compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted) - parsedCompactedNmap, err := strconv.ParseBool(compactedEnv) - if err != nil && len(compactedEnv) > 0 { - log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err) - } - if err == nil && !parsedCompactedNmap { - log.WithContext(ctx).Info("disabling compacted mode") - compactedNetworkMap = false - } - - ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") - expIDs := make(map[string]struct{}, len(ids)) - for _, id := range ids { - expIDs[id] = struct{}{} - } - return &Controller{ repo: newRepository(store), metrics: nMetrics, @@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App proxyController: proxyController, EphemeralPeersManager: ephemeralPeersManager, - - holder: types.NewHolder(), - expNewNetworkMap: newNetworkMapBuilder, - expNewNetworkMapAIDs: expIDs, - - compactedNetworkMap: compactedNetworkMap, } } @@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - var ( - account *types.Account - err error - ) - if c.experimentalNetworkMap(accountID) { - account = c.getAccountFromHolderOrInit(ctx, accountID) - } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("failed to get account: %v", err) - } + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) } globalStart := time.Now() @@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin routers := account.GetResourceRoutersMap() groupIDToUserIDs := 
account.GetActiveGroupUsers() - if c.experimentalNetworkMap(accountID) { - c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) - } - proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) @@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin c.metrics.CountCalcPostureChecksDuration(time.Since(start)) start = time.Now() - var remotePeerNetworkMap *types.NetworkMap - - switch { - case c.experimentalNetworkMap(accountID): - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - case c.compactedNetworkMap: - remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - default: - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } + remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) @@ -317,11 +257,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID // UpdatePeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. 
-func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { - if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return fmt.Errorf("recalculate network map cache: %v", err) +func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) } - return c.sendUpdateAccountPeers(ctx, accountID) } @@ -371,16 +310,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return err } - var remotePeerNetworkMap *types.NetworkMap - - switch { - case c.experimentalNetworkMap(accountId): - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - case c.compactedNetworkMap: - remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - default: - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } + remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -404,9 +334,13 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return nil } -func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { +func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", 
accountID, util.GetCallerName()) + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) + } + bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) b := bufUpd.(*bufferUpdate) @@ -421,14 +355,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str go func() { defer b.mu.Unlock() - _ = c.UpdateAccountPeers(ctx, accountID) + _ = c.sendUpdateAccountPeers(ctx, accountID) if !b.update.Load() { return } b.update.Store(false) if b.next == nil { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { - _ = c.UpdateAccountPeers(ctx, accountID) + _ = c.sendUpdateAccountPeers(ctx, accountID) }) return } @@ -451,17 +385,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return peer, emptyMap, nil, 0, nil } - var ( - account *types.Account - err error - ) - if c.experimentalNetworkMap(accountID) { - account = c.getAccountFromHolderOrInit(ctx, accountID) - } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, 0, err - } + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err } account.InjectProxyPolicies(ctx) @@ -493,20 +419,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return nil, nil, nil, 0, err } - var networkMap *types.NetworkMap - - if c.experimentalNetworkMap(accountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - } else { - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - if c.compactedNetworkMap { - networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, 
peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } - } + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -518,108 +434,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return peer, networkMap, postureChecks, dnsFwdPort, nil } -func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { - c.enrichAccountFromHolder(account) - account.InitNetworkMapBuilderIfNeeded(validatedPeers) -} - -func (c *Controller) getPeerNetworkMapExp( - ctx context.Context, - accountId string, - peerId string, - validatedPeers map[string]struct{}, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - metrics *telemetry.AccountManagerMetrics, -) *types.NetworkMap { - account := c.getAccountFromHolderOrInit(ctx, accountId) - if account == nil { - log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) - return &types.NetworkMap{ - Network: &types.Network{}, - } - } - - return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics) -} - -func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) { - c.enrichAccountFromHolder(account) - account.OnPeersAddedUpdNetworkMapCache(peerIds...) 
-} - -func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { - c.enrichAccountFromHolder(account) - return account.OnPeerDeletedUpdNetworkMapCache(peerId) -} - -func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { - account := c.getAccountFromHolder(accountId) - if account == nil { - return - } - account.UpdatePeerInNetworkMapCache(peer) -} - -func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { - account.RecalculateNetworkMapCache(validatedPeers) - c.updateAccountInHolder(account) -} - -func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { - if c.experimentalNetworkMap(accountId) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - return err - } - validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) - return err - } - c.recalculateNetworkMapCache(account, validatedPeers) - } - return nil -} - -func (c *Controller) experimentalNetworkMap(accountId string) bool { - _, ok := c.expNewNetworkMapAIDs[accountId] - return c.expNewNetworkMap || ok -} - -func (c *Controller) enrichAccountFromHolder(account *types.Account) { - a := c.holder.GetAccount(account.Id) - if a == nil { - c.holder.AddAccount(account) - return - } - account.NetworkMapCache = a.NetworkMapCache - if account.NetworkMapCache == nil { - return - } - c.holder.AddAccount(account) -} - -func (c *Controller) getAccountFromHolder(accountID string) *types.Account { - return c.holder.GetAccount(accountID) -} - -func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account { - a := c.holder.GetAccount(accountID) - if a != nil { - return a - } - account, 
err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure) - if err != nil { - return nil - } - return account -} - -func (c *Controller) updateAccountInHolder(account *types.Account) { - c.holder.AddAccount(account) -} - // GetDNSDomain returns the configured dnsDomain func (c *Controller) GetDNSDomain(settings *types.Settings) string { if settings == nil { @@ -756,16 +570,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t } func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { - peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) - if err != nil { - return fmt.Errorf("failed to get peers by ids: %w", err) - } - - for _, peer := range peers { - c.UpdatePeerInNetworkMapCache(accountID, peer) - } - - err = c.bufferSendUpdateAccountPeers(ctx, accountID) + err := c.bufferSendUpdateAccountPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) } @@ -775,14 +580,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs) - c.onPeersAddedUpdNetworkMapCache(account, peerIDs...) 
- } return c.bufferSendUpdateAccountPeers(ctx, accountID) } @@ -817,19 +614,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI MessageType: network_map.MessageTypeNetworkMap, }) c.peersUpdateManager.CloseChannel(ctx, peerID) - - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - continue - } - err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) - if err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err) - continue - } - } } return c.bufferSendUpdateAccountPeers(ctx, accountID) @@ -872,21 +656,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N return nil, err } - var networkMap *types.NetworkMap - - if c.experimentalNetworkMap(peer.AccountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil) - } else { - account.InjectProxyPolicies(ctx) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - if c.compactedNetworkMap { - networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) - } - } + account.InjectProxyPolicies(ctx) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, 
groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index 64caac861..44d8f7d72 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -12,18 +12,15 @@ import ( ) const ( - EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" - EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" - DnsForwarderPort = nbdns.ForwarderServerPort OldForwarderPort = nbdns.ForwarderClientPort DnsForwarderPortMinVersion = "v0.59.0" ) type Controller interface { - UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error - BufferUpdateAccountPeers(ctx context.Context, accountID string) error + BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetDNSDomain(settings *types.Settings) string StartWarmup(context.Context) diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index 4e86d2973..073a75d3b 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -44,17 +44,17 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder { } // BufferUpdateAccountPeers mocks base method. 
-func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { +func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason) ret0, _ := ret[0].(error) return ret0 } // BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. -func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call { +func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason) } // CountStreams mocks base method. @@ -238,15 +238,15 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId a } // UpdateAccountPeers mocks base method. -func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error { +func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason) ret0, _ := ret[0].(error) return ret0 } // UpdateAccountPeers indicates an expected call of UpdateAccountPeers. 
-func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call { +func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason) } diff --git a/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go index fc3010dd1..314e84501 100644 --- a/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go @@ -62,7 +62,7 @@ func (a *MockAccountManager) GetDeletePeerCalls() int { return a.deletePeerCalls } -func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { a.mu.Lock() defer a.mu.Unlock() if a.bufferUpdateCalls == nil { @@ -248,7 +248,7 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { return err } } - mockAM.BufferUpdateAccountPeers(ctx, accountID) + mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{}) return nil }). 
Times(1) diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index d3f8f44ff..c913efb92 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -178,7 +179,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs } } - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index aa7cd8630..53c52b3aa 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,9 +11,9 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error - Disconnect(ctx context.Context, proxyID string) error - Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) + Disconnect(ctx context.Context, proxyID, sessionID string) error + Heartbeat(ctx context.Context, p *Proxy) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusters(ctx context.Context) ([]Cluster, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) 
*bool diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index d13334e83..341e8c943 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -13,7 +13,8 @@ import ( // store defines the interface for proxy persistence operations type store interface { SaveProxy(ctx context.Context, p *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + DisconnectProxy(ctx context.Context, proxyID, sessionID string) error + UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool @@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { // Connect registers a new proxy connection in the database. // capabilities may be nil for old proxies that do not report them. 
-func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { +func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { now := time.Now() var caps proxy.Capabilities if capabilities != nil { @@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress } p := &proxy.Proxy{ ID: proxyID, + SessionID: sessionID, ClusterAddress: clusterAddress, IPAddress: ipAddress, LastSeen: now, @@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress if err := m.store.SaveProxy(ctx, p); err != nil { log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err) - return err + return nil, err } log.WithContext(ctx).WithFields(log.Fields{ "proxyID": proxyID, + "sessionID": sessionID, "clusterAddress": clusterAddress, "ipAddress": ipAddress, }).Info("proxy connected") - return nil + return p, nil } -// Disconnect marks a proxy as disconnected in the database -func (m Manager) Disconnect(ctx context.Context, proxyID string) error { - now := time.Now() - p := &proxy.Proxy{ - ID: proxyID, - Status: "disconnected", - DisconnectedAt: &now, - LastSeen: now, - } - - if err := m.store.SaveProxy(ctx, p); err != nil { - log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) +// Disconnect marks a proxy as disconnected in the database. 
+func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error { + if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err) return err } log.WithContext(ctx).WithFields(log.Fields{ - "proxyID": proxyID, + "proxyID": proxyID, + "sessionID": sessionID, }).Info("proxy disconnected") return nil } -// Heartbeat updates the proxy's last seen timestamp -func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { - if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { - log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) +// Heartbeat updates the proxy's last seen timestamp. +func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { + if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil { + log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err) return err } - log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID) + log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID) m.metrics.IncrementProxyHeartbeatCount() return nil } diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index 282ca0ba5..98d97b3c6 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte } // Connect mocks base method. 
-func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { +func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) + ret0, _ := ret[0].(*Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) } // Disconnect mocks base method. -func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error { +func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID) + ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID) ret0, _ := ret[0].(error) return ret0 } // Disconnect indicates an expected call of Disconnect. 
-func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID) } // GetActiveClusterAddresses mocks base method. @@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca } // Heartbeat mocks base method. -func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "Heartbeat", ctx, p) ret0, _ := ret[0].(error) return ret0 } // Heartbeat indicates an expected call of Heartbeat. -func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p) } // MockController is a mock of Controller interface. 
diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 339c82446..dcedb8811 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -18,12 +18,13 @@ type Capabilities struct { // Proxy represents a reverse proxy instance type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` + SessionID string `gorm:"type:varchar(36)"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` IPAddress string `gorm:"type:varchar(45)"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time - Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` Capabilities Capabilities `gorm:"embedded"` CreatedAt time.Time UpdatedAt time.Time diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index 28461641d..3485d51fe 100644 --- a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -2,7 +2,7 @@ package manager import ( "context" - "net" + "net/netip" "testing" "time" @@ -56,7 +56,8 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor Key: "test-key", DNSLabel: "test-peer", Name: "test-peer", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, }, @@ -85,7 +86,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ 
context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index ed9d4201b..d03a8dc82 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand/v2" + "net" "net/http" "os" "slices" @@ -25,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -231,7 +233,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate}) return s, nil } @@ -515,7 +517,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s } m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: 
types.UpdateResourceService, Operation: types.UpdateOperationUpdate}) return service, nil } @@ -819,7 +821,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -860,7 +862,7 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster) } - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -916,7 +918,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate}) return nil } @@ -1098,11 +1100,11 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s } m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate}) serviceURL := 
"https://" + svc.Domain if service.IsL4Protocol(svc.Mode) { - serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort) + serviceURL = fmt.Sprintf("%s://%s", svc.Mode, net.JoinHostPort(svc.Domain, strconv.Itoa(int(svc.ListenPort)))) } return &service.ExposeServiceResponse{ @@ -1210,7 +1212,7 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -1261,7 +1263,7 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -1271,7 +1273,7 @@ func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]a return meta } meta["peer_name"] = peer.Name - if peer.IP != nil { + if peer.IP.IsValid() { meta["peer_ip"] = peer.IP.String() } return meta diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 54ac8ab18..46e79f1e5 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ 
b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -3,7 +3,7 @@ package manager import ( "context" "errors" - "net" + "net/netip" "testing" "time" @@ -405,7 +405,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { testPeer := &nbpeer.Peer{ ID: ownerPeerID, Name: "test-peer", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), } newEphemeralService := func() *rpservice.Service { @@ -447,7 +448,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { storedActivity = activityID.(activity.Activity) }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). @@ -549,7 +550,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { storedActivity = activityID.(activity.Activity) }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). @@ -593,7 +594,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) { storedMeta = meta }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). 
@@ -682,7 +683,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { Key: "test-key", DNSLabel: "test-peer", Name: "test-peer", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, }, @@ -704,7 +706,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, @@ -751,7 +753,8 @@ func Test_validateExposePermission(t *testing.T) { Key: "other-key", DNSLabel: "other-peer", Name: "other-peer", - IP: net.ParseIP("100.64.0.2"), + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), Status: &nbpeer.PeerStatus{LastSeen: time.Now()}, Meta: nbpeer.PeerSystemMeta{Hostname: "other-peer"}, }) @@ -1173,7 +1176,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) { mockAcct.EXPECT(). StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any()) mockAcct.EXPECT(). 
- UpdateAccountPeers(ctx, accountID) + UpdateAccountPeers(ctx, accountID, gomock.Any()) err = mgr.DeleteService(ctx, accountID, userID, service.ID) require.NoError(t, err) diff --git a/management/internals/modules/zones/manager/manager.go b/management/internals/modules/zones/manager/manager.go index 8548dd48c..439671e65 100644 --- a/management/internals/modules/zones/manager/manager.go +++ b/management/internals/modules/zones/manager/manager.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -144,7 +145,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta()) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationUpdate}) return zone, nil } @@ -206,7 +207,7 @@ func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/zones/records/manager/manager.go b/management/internals/modules/zones/records/manager/manager.go index 5374a2ef2..7458b41db 100644 --- a/management/internals/modules/zones/records/manager/manager.go +++ b/management/internals/modules/zones/records/manager/manager.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" 
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -95,7 +96,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationCreate}) return record, nil } @@ -154,7 +155,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationUpdate}) return record, nil } @@ -201,7 +202,7 @@ func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 24dfb641b..f2ab0a2c4 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -30,6 +30,7 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext 
"github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" @@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler { }) } +func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter { + return Create(s, func() *middleware.APIRateLimiter { + cfg, enabled := middleware.RateLimiterConfigFromEnv() + limiter := 
middleware.NewAPIRateLimiter(cfg) + limiter.SetEnabled(enabled) + return limiter + }) +} + func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { trustedPeers := s.Config.ReverseProxy.TrustedPeers @@ -163,7 +173,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider()) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 9a8e45d33..89bdf0abe 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -66,6 +67,12 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager { }) } +func (s *BaseServer) SessionStore() *auth.SessionStore { + return Create(s, func() *auth.SessionStore { + return auth.NewSessionStore(s.CacheStore()) + }) +} + func (s *BaseServer) AuthManager() auth.Manager { audiences := s.Config.GetAuthAudiences() audience := s.Config.HttpConfig.AuthAudience diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index ef417d3cf..12402b420 100644 --- a/management/internals/shared/grpc/conversion.go 
+++ b/management/internals/shared/grpc/conversion.go @@ -3,12 +3,15 @@ package grpc import ( "context" "fmt" + "net/netip" "net/url" "strings" log "github.com/sirupsen/logrus" + goproto "google.golang.org/protobuf/proto" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" @@ -17,8 +20,9 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/route" + nbroute "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/sshauth" ) @@ -100,7 +104,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) } - return &proto.PeerConfig{ + peerConfig := &proto.PeerConfig{ Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), SshConfig: sshConfig, Fqdn: fqdn, @@ -111,9 +115,25 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set AlwaysUpdate: settings.AutoUpdateAlways, }, } + + if peer.SupportsIPv6() && peer.IPv6.IsValid() && network.NetV6.IP != nil { + ones, _ := network.NetV6.Mask.Size() + v6Prefix := netip.PrefixFrom(peer.IPv6.Unmap(), ones) + if b, err := netiputil.EncodePrefix(v6Prefix); err == nil { + peerConfig.AddressV6 = b + } + } + + return peerConfig } func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, 
peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + // IPv6 data in AllowedIPs and SourcePrefixes wildcard expansion depends on + // whether the target peer supports IPv6. Routes and firewall rules are already + // filtered at the source (network map builder). + includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid() + useSourcePrefixes := peer.SupportsSourcePrefixes() + response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), NetworkMap: &proto.NetworkMap{ @@ -132,15 +152,15 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb response.NetworkMap.PeerConfig = response.PeerConfig remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6) response.RemotePeers = remotePeers response.NetworkMap.RemotePeers = remotePeers response.RemotePeersIsEmpty = len(remotePeers) == 0 response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty - response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6) - firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes) response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 @@ -195,11 +215,15 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m return hashedUsers, machineUsers } -func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { +func appendRemotePeerConfig(dst 
[]*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig { for _, rPeer := range peers { + allowedIPs := []string{rPeer.IP.String() + "/32"} + if includeIPv6 && rPeer.IPv6.IsValid() { + allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128") + } dst = append(dst, &proto.RemotePeerConfig{ WgPubKey: rPeer.Key, - AllowedIps: []string{rPeer.IP.String() + "/32"}, + AllowedIps: allowedIPs, SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, Fqdn: rPeer.FQDN(dnsName), AgentVersion: rPeer.Meta.WtVersion, @@ -253,7 +277,7 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { } } -func toProtocolRoutes(routes []*route.Route) []*proto.Route { +func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route { protoRoutes := make([]*proto.Route, 0, len(routes)) for _, r := range routes { protoRoutes = append(protoRoutes, toProtocolRoute(r)) @@ -261,7 +285,7 @@ func toProtocolRoutes(routes []*route.Route) []*proto.Route { return protoRoutes } -func toProtocolRoute(route *route.Route) *proto.Route { +func toProtocolRoute(route *nbroute.Route) *proto.Route { return &proto.Route{ ID: string(route.ID), NetID: string(route.NetID), @@ -277,29 +301,70 @@ func toProtocolRoute(route *route.Route) *proto.Route { } // toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) +// When useSourcePrefixes is true, the compact SourcePrefixes field is populated +// alongside the deprecated PeerIP for forward compatibility. +// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes +// when includeIPv6 is true. 
+func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, 0, len(rules)) for i := range rules { rule := rules[i] fwRule := &proto.FirewallRule{ PolicyID: []byte(rule.PolicyID), - PeerIP: rule.PeerIP, + PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility Direction: getProtoDirection(rule.Direction), Action: getProtoAction(rule.Action), Protocol: getProtoProtocol(rule.Protocol), Port: rule.Port, } + if useSourcePrefixes && rule.PeerIP != "" { + result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...) + } + if shouldUsePortRange(fwRule) { fwRule.PortInfo = rule.PortRange.ToProto() } - result[i] = fwRule + result = append(result, fwRule) } return result } + +// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any +// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified). +func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule { + addr, err := netip.ParseAddr(rule.PeerIP) + if err != nil { + return nil + } + + if !addr.IsUnspecified() { + fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())} + return nil + } + + // IPv4Unspecified/0 is always valid, error is impossible. + v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0)) + fwRule.SourcePrefixes = [][]byte{v4Wildcard} + + if !includeIPv6 { + return nil + } + + v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule) + v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility + // IPv6Unspecified/0 is always valid, error is impossible. 
+ v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0)) + v6Rule.SourcePrefixes = [][]byte{v6Wildcard} + if shouldUsePortRange(v6Rule) { + v6Rule.PortInfo = rule.PortRange.ToProto() + } + return []*proto.FirewallRule{v6Rule} +} + // getProtoDirection converts the direction to proto.RuleDirection. func getProtoDirection(direction int) proto.RuleDirection { if direction == types.FirewallRuleDirectionOUT { diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index a5e352e75..6763a3ba3 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -11,11 +11,14 @@ import ( "fmt" "net/http" "net/url" + "os" + "strconv" "strings" "sync" "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "golang.org/x/oauth2" "google.golang.org/grpc/codes" @@ -81,14 +84,44 @@ type ProxyServiceServer struct { // Store for PKCE verifiers pkceVerifierStore *PKCEVerifierStore + // tokenTTL is the lifetime of one-time tokens generated for proxy + // authentication. Defaults to defaultProxyTokenTTL when zero. + tokenTTL time.Duration + + // snapshotBatchSize is the number of mappings per gRPC message during + // initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE. + snapshotBatchSize int + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute +const defaultProxyTokenTTL = 5 * time.Minute + +const defaultSnapshotBatchSize = 500 + +func snapshotBatchSizeFromEnv() int { + if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return defaultSnapshotBatchSize +} + +// proxyTokenTTL returns the configured token TTL or the default when unset. 
+func (s *ProxyServiceServer) proxyTokenTTL() time.Duration { + if s.tokenTTL > 0 { + return s.tokenTTL + } + return defaultProxyTokenTTL +} + // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string + sessionID string address string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer @@ -108,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + snapshotBatchSize: snapshotBatchSizeFromEnv(), cancel: cancel, } go s.cleanupStaleProxies(ctx) @@ -166,9 +200,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy address is invalid") } + sessionID := uuid.NewString() + + if old, loaded := s.connectedProxies.Load(proxyID); loaded { + oldConn := old.(*proxyConnection) + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "old_session_id": oldConn.sessionID, + "new_session_id": sessionID, + }).Info("Superseding existing proxy connection") + oldConn.cancel() + } + connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ proxyID: proxyID, + sessionID: sessionID, address: proxyAddress, capabilities: req.GetCapabilities(), stream: stream, @@ -177,79 +224,93 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest cancel: cancel, } - s.connectedProxies.Store(proxyID, conn) - if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) - } - // Register proxy in database with capabilities var caps *proxy.Capabilities if c := req.GetCapabilities(); c != nil { caps = &proxy.Capabilities{ SupportsCustomPorts: c.SupportsCustomPorts, RequireSubdomain: c.RequireSubdomain, - SupportsCrowdsec: c.SupportsCrowdsec, + SupportsCrowdsec: c.SupportsCrowdsec, } } - if 
err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { + proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) + if err != nil { log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - s.connectedProxies.Delete(proxyID) - if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { - log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) - } + cancel() return status.Errorf(codes.Internal, "register proxy in database: %v", err) } - log.WithFields(log.Fields{ - "proxy_id": proxyID, - "address": proxyAddress, - "cluster_addr": proxyAddress, - "total_proxies": len(s.GetConnectedProxies()), - }).Info("Proxy registered in cluster") - defer func() { - if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil { - log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) - } - - s.connectedProxies.Delete(proxyID) - if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { - log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) - } - - cancel() - log.Infof("Proxy %s disconnected", proxyID) - }() + s.connectedProxies.Store(proxyID, conn) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + } if err := s.sendSnapshot(ctx, conn); err != nil { + if s.connectedProxies.CompareAndDelete(proxyID, conn) { + if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr) + } + } + cancel() + if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil 
{ + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr) + } return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) } errChan := make(chan error, 2) go s.sender(conn, errChan) - // Start heartbeat goroutine - go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo) + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "session_id": sessionID, + "address": proxyAddress, + "cluster_addr": proxyAddress, + "total_proxies": len(s.GetConnectedProxies()), + }).Info("Proxy registered in cluster") + defer func() { + if !s.connectedProxies.CompareAndDelete(proxyID, conn) { + log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID) + cancel() + return + } + + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { + log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) + } + if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil { + log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) + } + + cancel() + log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) + }() + + go s.heartbeat(connCtx, proxyRecord) select { case err := <-errChan: + log.WithContext(ctx).Warnf("Failed to send update: %v", err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err) case <-connCtx.Done(): + log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID) return connCtx.Err() } } // heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) { +func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { - log.WithContext(ctx).Debugf("Failed 
to update proxy %s heartbeat: %v", proxyID, err) + if err := s.proxyManager.Heartbeat(ctx, p); err != nil { + log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err) } case <-ctx.Done(): + log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID) return } } @@ -267,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return err } + // Send mappings in batches to reduce per-message gRPC overhead while + // staying well within the default 4 MB message size limit. + for i := 0; i < len(mappings); i += s.snapshotBatchSize { + end := i + s.snapshotBatchSize + if end > len(mappings) { + end = len(mappings) + } + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: mappings[i:end], + InitialSyncComplete: end == len(mappings), + }); err != nil { + return fmt.Errorf("send snapshot batch: %w", err) + } + } + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { return fmt.Errorf("send snapshot completion: %w", err) } - return nil - } - - for i, m := range mappings { - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{m}, - InitialSyncComplete: i == len(mappings)-1, - }); err != nil { - return fmt.Errorf("send proxy mapping: %w", err) - } } return nil @@ -300,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * continue } - token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) if err != nil { - log.WithFields(log.Fields{ - "service": service.Name, - "account": service.AccountID, - }).WithError(err).Error("failed to generate auth token for snapshot") - continue + return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) } m := service.ToProtoMapping(rpservice.Create, 
token, s.GetOIDCValidationConfig()) @@ -386,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes conn := value.(*proxyConnection) resp := s.perProxyMessage(update, conn.proxyID) if resp == nil { + log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() return true } select { case conn.sendChan <- resp: log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: - log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) + log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() } return true }) @@ -472,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd } msg := s.perProxyMessage(updateResponse, proxyID) if msg == nil { + log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() continue } select { case conn.sendChan <- msg: log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) default: - log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() } } } @@ -504,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. -// Returns nil if token generation fails (the proxy should be skipped). 
+// Returns nil if token generation fails; the caller must disconnect the +// proxy so it can resync via a fresh snapshot on reconnect. func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) for _, mapping := range update.Mapping { @@ -513,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo continue } - token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL()) if err != nil { log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) return nil diff --git a/management/internals/shared/grpc/proxy_snapshot_test.go b/management/internals/shared/grpc/proxy_snapshot_test.go new file mode 100644 index 000000000..e0c7425c5 --- /dev/null +++ b/management/internals/shared/grpc/proxy_snapshot_test.go @@ -0,0 +1,174 @@ +package grpc + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// recordingStream captures all messages sent via Send so tests can inspect +// batching behaviour without a real gRPC transport. 
+type recordingStream struct { + grpc.ServerStream + messages []*proto.GetMappingUpdateResponse +} + +func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error { + s.messages = append(s.messages, m) + return nil +} + +func (s *recordingStream) Context() context.Context { return context.Background() } +func (s *recordingStream) SetHeader(metadata.MD) error { return nil } +func (s *recordingStream) SendHeader(metadata.MD) error { return nil } +func (s *recordingStream) SetTrailer(metadata.MD) {} +func (s *recordingStream) SendMsg(any) error { return nil } +func (s *recordingStream) RecvMsg(any) error { return nil } + +// makeServices creates n enabled services assigned to the given cluster. +func makeServices(n int, cluster string) []*rpservice.Service { + services := make([]*rpservice.Service, n) + for i := range n { + services[i] = &rpservice.Service{ + ID: fmt.Sprintf("svc-%d", i), + AccountID: "acct-1", + Name: fmt.Sprintf("svc-%d", i), + Domain: fmt.Sprintf("svc-%d.example.com", i), + ProxyCluster: cluster, + Enabled: true, + Targets: []*rpservice.Target{ + {TargetType: rpservice.TargetTypeHost, TargetId: "host-1"}, + }, + } + } + return services +} + +func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer { + t.Helper() + s := &ProxyServiceServer{ + tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)), + snapshotBatchSize: batchSize, + } + s.SetProxyController(newTestProxyController()) + return s +} + +func TestSendSnapshot_BatchesMappings(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 7 // 3 + 3 + 1 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + stream: 
stream, + } + + err := s.sendSnapshot(context.Background(), conn) + require.NoError(t, err) + + // Expect ceil(7/3) = 3 messages + require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages") + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete") + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete") + + assert.Len(t, stream.messages[2].Mapping, 1) + assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete") + + // Verify all service IDs are present exactly once + seen := make(map[string]bool) + for _, msg := range stream.messages { + for _, m := range msg.Mapping { + assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id) + seen[m.Id] = true + } + } + assert.Len(t, seen, totalServices) +} + +func TestSendSnapshot_ExactBatchMultiple(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 6 // exactly 2 batches + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 2) + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete) + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.True(t, stream.messages[1].InitialSyncComplete) +} + +func TestSendSnapshot_SingleBatch(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 100 + const totalServices = 5 + + ctrl := gomock.NewController(t) + mgr := 
rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "all mappings should fit in one batch") + assert.Len(t, stream.messages[0].Mapping, totalServices) + assert.True(t, stream.messages[0].InitialSyncComplete) +} + +func TestSendSnapshot_EmptySnapshot(t *testing.T) { + const cluster = "cluster.example.com" + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil) + + s := newSnapshotTestServer(t, 500) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete") + assert.Empty(t, stream.messages[0].Mapping) + assert.True(t, stream.messages[0].InitialSyncComplete) +} diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index de4e96d93..5a7a457df 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan // registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. 
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { ch := make(chan *proto.GetMappingUpdateResponse, 10) + ctx, cancel := context.WithCancel(context.Background()) conn := &proxyConnection{ proxyID: proxyID, address: clusterAddr, capabilities: caps, sendChan: ch, + ctx: ctx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6e8358f02..70024bac6 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + jwtv5 "github.com/golang-jwt/jwt/v5" pb "github.com/golang/protobuf/proto" // nolint "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" @@ -67,6 +68,7 @@ type Server struct { appMetrics telemetry.AppMetrics peerLocks sync.Map authManager auth.Manager + sessionStore *auth.SessionStore logBlockedPeers bool blockPeersWithSameConfig bool @@ -98,6 +100,7 @@ func NewServer( integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, oAuthConfigProvider idp.OAuthConfigProvider, + sessionStore *auth.SessionStore, ) (*Server, error) { if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams @@ -140,6 +143,7 @@ func NewServer( integratedPeerValidator: integratedPeerValidator, networkMapController: networkMapController, oAuthConfigProvider: oAuthConfigProvider, + sessionStore: sessionStore, loginFilter: newLoginFilter(), @@ -535,7 +539,7 @@ func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID st log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key) } -func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { +func (s *Server) 
validateToken(ctx context.Context, peerKey, jwtToken string) (string, error) { if s.authManager == nil { return "", status.Errorf(codes.Internal, "missing auth manager") } @@ -545,6 +549,10 @@ func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, er return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } + if err := s.claimLoginToken(ctx, peerKey, jwtToken, token); err != nil { + return "", err + } + // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth) if err != nil { @@ -672,11 +680,21 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee BlockLANAccess: meta.GetFlags().GetBlockLANAccess(), BlockInbound: meta.GetFlags().GetBlockInbound(), LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(), + DisableIPv6: meta.GetFlags().GetDisableIPv6(), }, - Files: files, + Files: files, + Capabilities: capabilitiesToInt32(meta.GetCapabilities()), } } +func capabilitiesToInt32(caps []proto.PeerCapability) []int32 { + result := make([]int32, len(caps)) + for i, c := range caps { + result[i] = int32(c) + } + return result +} + func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { @@ -828,6 +846,31 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne return loginResp, nil } +func (s *Server) claimLoginToken(ctx context.Context, peerKey, jwtToken string, token *jwtv5.Token) error { + if s.sessionStore == nil || token == nil { + return nil + } + + exp, err := token.Claims.GetExpirationTime() + if err != nil || exp == nil { + log.WithContext(ctx).Warnf("JWT has no usable exp claim for peer %s", peerKey) + return status.Error(codes.Unauthenticated, "jwt token has no expiration") + } + 
+ err = s.sessionStore.RegisterToken(ctx, jwtToken, exp.Time) + if err == nil { + return nil + } + + if errors.Is(err, auth.ErrTokenAlreadyUsed) || errors.Is(err, auth.ErrTokenExpired) { + log.WithContext(ctx).Warnf("%v for peer %s", err, peerKey) + return status.Error(codes.Unauthenticated, err.Error()) + } + + log.WithContext(ctx).Warnf("failed to claim JWT for peer %s: %v", peerKey, err) + return status.Error(codes.Unavailable, "failed to claim jwt token") +} + // processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if // the token is valid. // @@ -838,7 +881,7 @@ func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginReque if loginReq.GetJwtToken() != "" { var err error for i := 0; i < 3; i++ { - userID, err = s.validateToken(ctx, loginReq.GetJwtToken()) + userID, err = s.validateToken(ctx, peerKey.String(), loginReq.GetJwtToken()) if err == nil { break } diff --git a/management/server/account.go b/management/server/account.go index 7d53cef03..45b99839f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -329,6 +329,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco updateAccountPeers = true } + if ipv6SettingsChanged(oldSettings, newSettings) { + if err = am.updatePeerIPv6Addresses(ctx, transaction, accountID, newSettings); err != nil { + return err + } + updateAccountPeers = true + } + if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.DNSDomain != newSettings.DNSDomain || @@ -338,7 +345,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled { - groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, 
transaction, accountID) + groupsUpdated, groupChangesAffectPeers, err = am.propagateUserGroupMemberships(ctx, transaction, accountID) if err != nil { return err } @@ -393,6 +400,22 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) } + oldIPv6On := len(oldSettings.IPv6EnabledGroups) > 0 + newIPv6On := len(newSettings.IPv6EnabledGroups) > 0 + if oldIPv6On != newIPv6On { + if newIPv6On { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountIPv6Enabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountIPv6Disabled, nil) + } + } + if oldSettings.NetworkRangeV6 != newSettings.NetworkRangeV6 { + eventMeta := map[string]any{ + "old_network_range_v6": oldSettings.NetworkRangeV6.String(), + "new_network_range_v6": newSettings.NetworkRangeV6.String(), + } + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) + } if reloadReverseProxy { if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err) @@ -400,12 +423,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { - go am.UpdateAccountPeers(ctx, accountID) + go am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceAccountSettings, Operation: types.UpdateOperationUpdate}) } return newSettings, nil } +func ipv6SettingsChanged(old, updated *types.Settings) bool { + if old.NetworkRangeV6 != updated.NetworkRangeV6 { + return true + } + oldGroups := slices.Clone(old.IPv6EnabledGroups) + newGroups := slices.Clone(updated.IPv6EnabledGroups) + slices.Sort(oldGroups) + slices.Sort(newGroups) + return !slices.Equal(oldGroups, newGroups) +} + func (am 
*DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { @@ -432,9 +466,38 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra } } + if err := validateIPv6EnabledGroups(ctx, transaction, accountID, newSettings.IPv6EnabledGroups); err != nil { + return err + } + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID) } +// validateIPv6EnabledGroups checks that all referenced IPv6-enabled group IDs exist in the account. +func validateIPv6EnabledGroups(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("get groups for IPv6 validation: %w", err) + } + + existing := make(map[string]struct{}, len(groups)) + for _, g := range groups { + existing[g.ID] = struct{}{} + } + + for _, gid := range groupIDs { + if _, ok := existing[gid]; !ok { + return status.Errorf(status.InvalidArgument, "IPv6 enabled group %s does not exist", gid) + } + } + + return nil +} + func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled { if newSettings.RoutingPeerDNSResolutionEnabled { @@ -739,37 +802,8 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) } - for _, otherUser := range account.Users { - if otherUser.Id == userID { - continue - } - - if 
otherUser.IsServiceUser { - err = am.deleteServiceUser(ctx, accountID, userID, otherUser) - if err != nil { - return err - } - continue - } - - userInfo, ok := userInfosMap[otherUser.Id] - if !ok { - return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id) - } - - _, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo) - if deleteUserErr != nil { - return deleteUserErr - } - } - - userInfo, ok := userInfosMap[userID] - if ok { - _, err = am.deleteRegularUser(ctx, accountID, userID, userInfo) - if err != nil { - log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) - return err - } + if err = am.deleteAccountUsers(ctx, accountID, userID, account.Users, userInfosMap); err != nil { + return err } err = am.Store.DeleteAccount(ctx, account) @@ -787,6 +821,40 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } +func (am *DefaultAccountManager) deleteAccountUsers(ctx context.Context, accountID, initiatorUserID string, users map[string]*types.User, userInfosMap map[string]*types.UserInfo) error { + for _, otherUser := range users { + if otherUser.Id == initiatorUserID { + continue + } + + if otherUser.IsServiceUser { + if err := am.deleteServiceUser(ctx, accountID, initiatorUserID, otherUser); err != nil { + return err + } + continue + } + + userInfo, ok := userInfosMap[otherUser.Id] + if !ok { + return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id) + } + + if _, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo); err != nil { + return err + } + } + + userInfo, ok := userInfosMap[initiatorUserID] + if ok { + if _, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo); err != nil { + log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", initiatorUserID, err) + return err + } + } + + return nil +} + // AccountExists checks if an account exists. 
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID) @@ -1528,6 +1596,11 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } } + allGroupChanges := slices.Concat(addNewGroups, removeOldGroups) + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, userAuth.AccountId, allGroupChanges); err != nil { + return fmt.Errorf("reconcile IPv6 for group changes: %w", err) + } + if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } @@ -1581,7 +1654,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) - am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) + am.BufferUpdateAccountPeers(ctx, userAuth.AccountId, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate}) } return nil @@ -1913,6 +1986,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain, email, nam if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) } + + if allGroup, err := acc.GetGroupAll(); err == nil { + acc.Settings.IPv6EnabledGroups = []string{allGroup.ID} + } + return acc } @@ -2019,6 +2097,10 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain") } + if allGroup, err := newAccount.GetGroupAll(); err == nil { + newAccount.Settings.IPv6EnabledGroups = []string{allGroup.ID} + } + if err := am.Store.SaveAccount(ctx, newAccount); err != nil { 
log.WithContext(ctx).WithFields(log.Fields{ "accountId": newAccount.Id, @@ -2080,7 +2162,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc // propagateUserGroupMemberships propagates all account users' group memberships to their peers. // Returns true if any groups were modified, true if those updates affect peers and an error. -func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { +func (am *DefaultAccountManager) propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, false, err @@ -2102,29 +2184,13 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, } } - updatedGroups := []string{} - for _, user := range users { - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) - if err != nil { - return false, false, err - } + updatedGroups, err := propagateAutoGroupsForUsers(ctx, transaction, accountID, users, accountGroupPeers) + if err != nil { + return false, false, err + } - for _, peer := range userPeers { - for _, groupID := range user.AutoGroups { - if _, exists := accountGroupPeers[groupID]; !exists { - // we do not wanna create the groups here - log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) - continue - } - if _, exists := accountGroupPeers[groupID][peer.ID]; exists { - continue - } - if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { - return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) - } - updatedGroups = append(updatedGroups, groupID) - } - } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, updatedGroups); 
err != nil { + return false, false, fmt.Errorf("reconcile IPv6 for group changes: %w", err) } peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups) @@ -2135,6 +2201,35 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, return len(updatedGroups) > 0, peersAffected, nil } +// propagateAutoGroupsForUsers adds each user's peers to their AutoGroups where not already present. +// Returns the list of group IDs that were modified. +func propagateAutoGroupsForUsers(ctx context.Context, transaction store.Store, accountID string, users []*types.User, accountGroupPeers map[string]map[string]struct{}) ([]string, error) { + var updatedGroups []string + for _, user := range users { + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) + if err != nil { + return nil, err + } + + for _, peer := range userPeers { + for _, groupID := range user.AutoGroups { + if _, exists := accountGroupPeers[groupID]; !exists { + log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) + continue + } + if _, exists := accountGroupPeers[groupID][peer.ID]; exists { + continue + } + if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { + return nil, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) + } + updatedGroups = append(updatedGroups, groupID) + } + } + } + return updatedGroups, nil +} + // reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error { if !newNetworkRange.IsValid() { @@ -2156,10 +2251,10 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t return err } - var takenIPs []net.IP + var takenIPs []netip.Addr for _, peer := range peers { - newIP, err := 
types.AllocatePeerIP(newIPNet, takenIPs) + newIP, err := types.AllocatePeerIP(newNetworkRange, takenIPs) if err != nil { return status.Errorf(status.Internal, "allocate IP for peer %s: %v", peer.ID, err) } @@ -2183,13 +2278,199 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t return nil } +// updatePeerIPv6Addresses assigns or removes IPv6 addresses for all peers +// based on the current IPv6 settings. When IPv6 is enabled, peers without a +// v6 address get one allocated. When disabled, all v6 addresses are cleared. +// When the v6 range changes, all v6 addresses are reallocated. +func (am *DefaultAccountManager) checkIPv6Collision(ctx context.Context, transaction store.Store, accountID, peerID string, newIPv6 netip.Addr) error { + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") + if err != nil { + return fmt.Errorf("get peers: %w", err) + } + for _, p := range peers { + if p.ID != peerID && p.IPv6.IsValid() && p.IPv6 == newIPv6 { + return status.Errorf(status.InvalidArgument, "IPv6 %s is already assigned to peer %s", newIPv6, p.Name) + } + } + return nil +} + +func (am *DefaultAccountManager) updatePeerIPv6Addresses(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings) error { + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "") + if err != nil { + return fmt.Errorf("get peers: %w", err) + } + + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("get network: %w", err) + } + + if err := am.ensureIPv6Subnet(ctx, transaction, accountID, settings, network); err != nil { + return err + } + + allowedPeers, err := am.buildIPv6AllowedPeers(ctx, transaction, accountID, settings) + if err != nil { + return err + } + + v6Prefix, err := netip.ParsePrefix(network.NetV6.String()) + if err != nil { + return fmt.Errorf("parse IPv6 prefix: %w", 
err) + } + + if err := am.assignPeerIPv6Addresses(ctx, transaction, accountID, peers, network, allowedPeers, v6Prefix); err != nil { + return err + } + + log.WithContext(ctx).Infof("updated IPv6 addresses for %d peers in account %s (groups=%d)", + len(peers), accountID, len(settings.IPv6EnabledGroups)) + + return nil +} + +// reconcileIPv6ForGroupChanges checks whether the given group IDs overlap with +// the account's IPv6EnabledGroups. If they do, it runs a full IPv6 address +// reconciliation so that peers gaining or losing membership in an IPv6-enabled +// group get their addresses assigned or removed. +func (am *DefaultAccountManager) reconcileIPv6ForGroupChanges(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) error { + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("get account settings: %w", err) + } + + if len(settings.IPv6EnabledGroups) == 0 { + return nil + } + + enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups)) + for _, gid := range settings.IPv6EnabledGroups { + enabledSet[gid] = struct{}{} + } + + affected := false + for _, gid := range groupIDs { + if _, ok := enabledSet[gid]; ok { + affected = true + break + } + } + + if !affected { + return nil + } + + return am.updatePeerIPv6Addresses(ctx, transaction, accountID, settings) +} + +func (am *DefaultAccountManager) ensureIPv6Subnet(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings, network *types.Network) error { + if settings.NetworkRangeV6.IsValid() { + network.NetV6 = net.IPNet{ + IP: settings.NetworkRangeV6.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(settings.NetworkRangeV6.Bits(), 128), + } + return transaction.UpdateAccountNetworkV6(ctx, accountID, network.NetV6) + } + if network.NetV6.IP == nil { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + network.NetV6 = types.AllocateIPv6Subnet(r) + + // Sync settings 
to match the allocated subnet so SaveAccountSettings persists it. + ones, _ := network.NetV6.Mask.Size() + addr, _ := netip.AddrFromSlice(network.NetV6.IP) + settings.NetworkRangeV6 = netip.PrefixFrom(addr.Unmap(), ones) + + return transaction.UpdateAccountNetworkV6(ctx, accountID, network.NetV6) + } + return nil +} + +func (am *DefaultAccountManager) assignPeerIPv6Addresses( + ctx context.Context, transaction store.Store, accountID string, + peers []*nbpeer.Peer, network *types.Network, + allowedPeers map[string]struct{}, v6Prefix netip.Prefix, +) error { + takenV6 := make(map[netip.Addr]struct{}) + for _, peer := range peers { + if _, ok := allowedPeers[peer.ID]; ok && peer.IPv6.IsValid() && network.NetV6.Contains(peer.IPv6.AsSlice()) { + takenV6[peer.IPv6] = struct{}{} + } + } + + for _, peer := range peers { + _, allowed := allowedPeers[peer.ID] + oldIPv6 := peer.IPv6 + + if !allowed { + peer.IPv6 = netip.Addr{} + } else if !peer.IPv6.IsValid() || !network.NetV6.Contains(peer.IPv6.AsSlice()) { + newIP, err := allocateIPv6WithRetry(v6Prefix, takenV6, peer.ID) + if err != nil { + return err + } + peer.IPv6 = newIP + } + + if peer.IPv6 == oldIPv6 { + continue + } + + if err := transaction.SavePeer(ctx, accountID, peer); err != nil { + return fmt.Errorf("save peer %s: %w", peer.ID, err) + } + } + return nil +} + +func allocateIPv6WithRetry(prefix netip.Prefix, taken map[netip.Addr]struct{}, peerID string) (netip.Addr, error) { + for attempts := 0; attempts < 10; attempts++ { + newIP, err := types.AllocateRandomPeerIPv6(prefix) + if err != nil { + return netip.Addr{}, fmt.Errorf("allocate v6 for peer %s: %w", peerID, err) + } + if _, ok := taken[newIP]; !ok { + taken[newIP] = struct{}{} + return newIP, nil + } + } + return netip.Addr{}, fmt.Errorf("allocate v6 for peer %s: exhausted 10 attempts", peerID) +} + +func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings) 
(map[string]struct{}, error) { + if len(settings.IPv6EnabledGroups) == 0 { + return make(map[string]struct{}), nil + } + + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, fmt.Errorf("get groups: %w", err) + } + + enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups)) + for _, gid := range settings.IPv6EnabledGroups { + enabledSet[gid] = struct{}{} + } + + allowedPeers := make(map[string]struct{}) + for _, group := range groups { + if _, ok := enabledSet[group.ID]; !ok { + continue + } + for _, peerID := range group.Peers { + allowedPeers[peerID] = struct{}{} + } + } + return allowedPeers, nil +} + func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, peers []*nbpeer.Peer, peerID string, newIP netip.Addr) error { if !account.Network.Net.Contains(newIP.AsSlice()) { return status.Errorf(status.InvalidArgument, "IP %s is not within the account network range %s", newIP.String(), account.Network.Net.String()) } for _, peer := range peers { - if peer.ID != peerID && peer.IP.Equal(newIP.AsSlice()) { + if peer.ID != peerID && peer.IP == newIP { return status.Errorf(status.InvalidArgument, "IP %s is already assigned to peer %s", newIP.String(), peer.ID) } } @@ -2236,7 +2517,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, return fmt.Errorf("get peer: %w", err) } - if existingPeer.IP.Equal(newIP.AsSlice()) { + if existingPeer.IP == newIP { return nil } @@ -2271,7 +2552,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti eventMeta := peer.EventMeta(dnsDomain) oldIP := peer.IP.String() - peer.IP = newIP.AsSlice() + peer.IP = newIP err = transaction.SavePeer(ctx, accountID, peer) if err != nil { return fmt.Errorf("save peer: %w", err) @@ -2284,6 +2565,84 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti return nil } +// UpdatePeerIPv6 updates the IPv6 overlay address 
of a peer, validating it's +// within the account's v6 network range and not already taken. +func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) + if err != nil { + return fmt.Errorf("validate user permissions: %w", err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var updateNetworkMap bool + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var txErr error + updateNetworkMap, txErr = am.updatePeerIPv6InTransaction(ctx, transaction, accountID, peerID, newIPv6) + return txErr + }) + if err != nil { + return err + } + + if updateNetworkMap { + if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil { + return fmt.Errorf("notify network map controller: %w", err) + } + } + return nil +} + +// updatePeerIPv6InTransaction validates and applies an IPv6 address change within a store transaction. 
+func (am *DefaultAccountManager) updatePeerIPv6InTransaction(ctx context.Context, transaction store.Store, accountID, peerID string, newIPv6 netip.Addr) (bool, error) { + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, fmt.Errorf("get network: %w", err) + } + + if network.NetV6.IP == nil { + return false, status.Errorf(status.PreconditionFailed, "IPv6 is not configured for this account") + } + + if !network.NetV6.Contains(newIPv6.AsSlice()) { + return false, status.Errorf(status.InvalidArgument, "IP %s is not within the account IPv6 range %s", newIPv6, network.NetV6.String()) + } + + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, fmt.Errorf("get settings: %w", err) + } + + allowedPeers, err := am.buildIPv6AllowedPeers(ctx, transaction, accountID, settings) + if err != nil { + return false, err + } + if _, ok := allowedPeers[peerID]; !ok { + return false, status.Errorf(status.PreconditionFailed, "peer is not in any IPv6-enabled group") + } + + peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return false, fmt.Errorf("get peer: %w", err) + } + + if peer.IPv6.IsValid() && peer.IPv6 == newIPv6 { + return false, nil + } + + if err := am.checkIPv6Collision(ctx, transaction, accountID, peerID, newIPv6); err != nil { + return false, err + } + + peer.IPv6 = newIPv6 + if err := transaction.SavePeer(ctx, accountID, peer); err != nil { + return false, fmt.Errorf("save peer: %w", err) + } + + return true, nil +} + func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) { return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b4516d512..71af0645c 100644 --- a/management/server/account/manager.go +++ 
b/management/server/account/manager.go @@ -65,6 +65,7 @@ type Manager interface { DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error + UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) @@ -124,8 +125,8 @@ type Manager interface { GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error - UpdateAccountPeers(ctx context.Context, accountID string) - BufferUpdateAccountPeers(ctx context.Context, accountID string) + UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) + BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error GetStore() store.Store diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 36e5fe39f..7ffc41d73 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -111,15 +111,15 @@ func (mr *MockManagerMockRecorder) ApproveUser(ctx, accountID, initiatorUserID, } // BufferUpdateAccountPeers mocks base method. 
-func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason) } // BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. -func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason) } // BuildUserInfosForAccount mocks base method. @@ -1597,15 +1597,15 @@ func (mr *MockManagerMockRecorder) UpdateAccountOnboarding(ctx, accountID, userI } // UpdateAccountPeers mocks base method. -func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string) { +func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason) } // UpdateAccountPeers indicates an expected call of UpdateAccountPeers. 
-func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason) } // UpdateAccountSettings mocks base method. @@ -1709,6 +1709,18 @@ func (mr *MockManagerMockRecorder) UpdatePeerIP(ctx, accountID, userID, peerID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePeerIP", reflect.TypeOf((*MockManager)(nil).UpdatePeerIP), ctx, accountID, userID, peerID, newIP) } +func (m *MockManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePeerIPv6", ctx, accountID, userID, peerID, newIPv6) + ret0, _ := ret[0].(error) + return ret0 +} + +func (mr *MockManagerMockRecorder) UpdatePeerIPv6(ctx, accountID, userID, peerID, newIPv6 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePeerIPv6", reflect.TypeOf((*MockManager)(nil).UpdatePeerIPv6), ctx, accountID, userID, peerID, newIPv6) +} + // UpdateToPrimaryAccount mocks base method. func (m *MockManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error { m.ctrl.T.Helper() diff --git a/management/server/account/pat.go b/management/server/account/pat.go new file mode 100644 index 000000000..8e5e3e3f9 --- /dev/null +++ b/management/server/account/pat.go @@ -0,0 +1,8 @@ +package account + +const ( + // PATMinExpireDays is the minimum allowed Personal Access Token expiration in days. 
+ PATMinExpireDays = 1 + // PATMaxExpireDays is the maximum allowed Personal Access Token expiration in days. + PATMaxExpireDays = 365 +) diff --git a/management/server/account_test.go b/management/server/account_test.go index 4453d064e..6bb875f99 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -160,7 +160,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-1": { ID: peerID1, Key: "peer-1-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID1, DNSLabel: peerID1, Status: &nbpeer.PeerStatus{ @@ -174,7 +175,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-2": { ID: peerID2, Key: "peer-2-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID2, DNSLabel: peerID2, Status: &nbpeer.PeerStatus{ @@ -198,7 +200,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-1": { ID: peerID1, Key: "peer-1-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID1, DNSLabel: peerID1, Status: &nbpeer.PeerStatus{ @@ -213,7 +216,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-2": { ID: peerID2, Key: "peer-2-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID2, DNSLabel: peerID2, Status: &nbpeer.PeerStatus{ @@ -237,7 +241,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-1": { // ID: peerID1, // Key: "peer-1-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID1, // DNSLabel: peerID1, // Status: &PeerStatus{ @@ -251,7 +255,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-2": { // ID: peerID2, // Key: "peer-2-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 
1}), // Name: peerID2, // DNSLabel: peerID2, // Status: &PeerStatus{ @@ -265,7 +269,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-3": { // ID: peerID3, // Key: "peer-3-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID3, // DNSLabel: peerID3, // Status: &PeerStatus{ @@ -288,7 +292,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-1": { // ID: peerID1, // Key: "peer-1-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID1, // DNSLabel: peerID1, // Status: &PeerStatus{ @@ -302,7 +306,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-2": { // ID: peerID2, // Key: "peer-2-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID2, // DNSLabel: peerID2, // Status: &PeerStatus{ @@ -316,7 +320,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-3": { // ID: peerID3, // Key: "peer-3-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID3, // DNSLabel: peerID3, // Status: &PeerStatus{ @@ -339,7 +343,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-1": { // ID: peerID1, // Key: "peer-1-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID1, // DNSLabel: peerID1, // Status: &PeerStatus{ @@ -353,7 +357,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-2": { // ID: peerID2, // Key: "peer-2-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID2, // DNSLabel: peerID2, // Status: &PeerStatus{ @@ -367,7 +371,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-3": { // ID: peerID3, // Key: "peer-3-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID3, // DNSLabel: peerID3, // Status: &PeerStatus{ @@ -408,7 +412,7 @@ func 
TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) + networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -1084,7 +1088,7 @@ func TestAccountManager_AddPeer(t *testing.T) { t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) } - if !account.Network.Net.Contains(peer.IP) { + if !account.Network.Net.Contains(peer.IP.AsSlice()) { t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) } @@ -1148,7 +1152,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) } - if !account.Network.Net.Contains(peer.IP) { + if !account.Network.Net.Contains(peer.IP.AsSlice()) { t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) } @@ -1171,11 +1175,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_SaveGroup(t) -} - func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { testAccountManager_NetworkUpdates_SaveGroup(t) } @@ -1231,11 +1230,6 @@ func 
testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeletePolicy(t) -} - func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { testAccountManager_NetworkUpdates_DeletePolicy(t) } @@ -1274,11 +1268,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_SavePolicy(t) -} - func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { testAccountManager_NetworkUpdates_SavePolicy(t) } @@ -1332,11 +1321,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeletePeer(t) -} - func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { testAccountManager_NetworkUpdates_DeletePeer(t) } @@ -1397,11 +1381,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeleteGroup(t) -} - func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { testAccountManager_NetworkUpdates_DeleteGroup(t) } @@ -1633,75 +1612,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) } -func TestAccount_GetRoutesToSync(t *testing.T) { - _, prefix, err := route.ParseNetwork("192.168.64.0/24") - if err != nil { - t.Fatal(err) - } - _, prefix2, err := route.ParseNetwork("192.168.0.0/24") - if err != nil { - t.Fatal(err) - } - account := 
&types.Account{ - Peers: map[string]*nbpeer.Peer{ - "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, - }, - Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, - Routes: map[route.ID]*route.Route{ - "route-1": { - ID: "route-1", - Network: prefix, - NetID: "network-1", - Description: "network-1", - Peer: "peer-1", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - "route-2": { - ID: "route-2", - Network: prefix2, - NetID: "network-2", - Description: "network-2", - Peer: "peer-2", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - "route-3": { - ID: "route-3", - Network: prefix, - NetID: "network-1", - Description: "network-1", - Peer: "peer-2", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - }, - } - - routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2")) - - assert.Len(t, routes, 2) - routeIDs := make(map[route.ID]struct{}, 2) - for _, r := range routes { - routeIDs[r.ID] = struct{}{} - } - assert.Contains(t, routeIDs, route.ID("route-2")) - assert.Contains(t, routeIDs, route.ID("route-3")) - - emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3")) - - assert.Len(t, emptyRoutes, 0) -} - func TestAccount_Copy(t *testing.T) { account := &types.Account{ Id: "account1", @@ -1824,9 +1734,7 @@ func TestAccount_Copy(t *testing.T) { AccountID: "account1", }, }, - NetworkMapCache: &types.NetworkMapBuilder{}, } - account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) @@ -1857,7 +1765,7 @@ 
func hasNilField(x interface{}) error { if f := rv.Field(i); f.IsValid() { k := f.Kind() switch k { - case reflect.Ptr: + case reflect.Pointer: if f.IsNil() { return fmt.Errorf("field %s is nil", f.String()) } @@ -2311,6 +2219,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) { + ctx := context.Background() + + testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanUp) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + // Verify the already-expired peer is excluded at the store level + peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query") + assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired") + } + + // Only the non-expired peer with expiration enabled should be returned + require.Len(t, peers, 1) + assert.Equal(t, "notexpired01", peers[0].ID) +} + func TestAccount_GetInactivePeers(t *testing.T) { type test struct { name string @@ -2861,11 +2792,46 @@ func TestAccount_SetJWTGroups(t *testing.T) { account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ - "peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}, - "peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}, - "peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"}, - "peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"}, - "peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"}, + 
"peer1": { + ID: "peer1", + Key: "key1", + UserID: "user1", + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1"), + DNSLabel: "peer1.domain.test", + }, + "peer2": { + ID: "peer2", + Key: "key2", + UserID: "user1", + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2"), + DNSLabel: "peer2.domain.test", + }, + "peer3": { + ID: "peer3", + Key: "key3", + UserID: "user1", + IP: netip.AddrFrom4([4]byte{3, 3, 3, 3}), + IPv6: netip.MustParseAddr("fd00::3"), + DNSLabel: "peer3.domain.test", + }, + "peer4": { + ID: "peer4", + Key: "key4", + UserID: "user2", + IP: netip.AddrFrom4([4]byte{4, 4, 4, 4}), + IPv6: netip.MustParseAddr("fd00::4"), + DNSLabel: "peer4.domain.test", + }, + "peer5": { + ID: "peer5", + Key: "key5", + UserID: "user2", + IP: netip.AddrFrom4([4]byte{5, 5, 5, 5}), + IPv6: netip.MustParseAddr("fd00::5"), + DNSLabel: "peer5.domain.test", + }, }, Groups: map[string]*types.Group{ "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, @@ -3230,6 +3196,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel. return manager, updateManager, account, peer1, peer2, peer3 } +// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer +// wrappers wait for an expected update message. Sized for slow CI runners +// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take +// seconds. Only runs down on failure; passing tests return immediately +// when the channel delivers. 
+const peerUpdateTimeout = 5 * time.Second + func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { @@ -3248,7 +3221,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd if msg == nil { t.Errorf("Received nil update message, expected valid message") } - case <-time.After(500 * time.Millisecond): + case <-time.After(peerUpdateTimeout): t.Error("Timed out waiting for update message") } } @@ -3615,16 +3588,32 @@ func TestPropagateUserGroupMemberships(t *testing.T) { account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain}) require.NoError(t, err) - peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, Key: "key1", UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: account.Id, + Key: "key1", + UserID: initiatorId, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1"), + DNSLabel: "peer1.domain.test", + } err = manager.Store.AddPeerToAccount(ctx, peer1) require.NoError(t, err) - peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, Key: "key2", UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: account.Id, + Key: "key2", + UserID: initiatorId, + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2"), + DNSLabel: "peer2.domain.test", + } err = manager.Store.AddPeerToAccount(ctx, peer2) require.NoError(t, err) t.Run("should skip propagation when the user has no groups", func(t *testing.T) { - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ 
-3640,7 +3629,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { user.AutoGroups = append(user.AutoGroups, group1.ID) require.NoError(t, manager.Store.SaveUser(ctx, user)) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.True(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3678,7 +3667,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { }, true) require.NoError(t, err) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.True(t, groupsUpdated) assert.True(t, groupChangesAffectPeers) @@ -3693,7 +3682,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { }) t.Run("should not update membership or account peers when no changes", func(t *testing.T) { - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3706,7 +3695,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { user.AutoGroups = []string{"group1"} require.NoError(t, manager.Store.SaveUser(ctx, user)) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3820,11 +3809,10 @@ func 
TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get account") - newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP}) + newIP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), []netip.Addr{peer1.IP, peer2.IP}) require.NoError(t, err, "unable to allocate new IP") - newAddr := netip.MustParseAddr(newIP.String()) - err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr) + err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newIP) require.NoError(t, err, "unable to update peer IP") updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID) @@ -3982,6 +3970,109 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi } } +func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + ctx := context.Background() + accountID := account.Id + + // New accounts default to All group in IPv6EnabledGroups, so all 3 peers should have IPv6. + settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + require.NotEmpty(t, settings.IPv6EnabledGroups, "new account should have IPv6 enabled for All group") + + peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.True(t, p.IPv6.IsValid(), "peer %s should have IPv6 with All group enabled", p.ID) + } + + // Create a group with only peer1 and peer2. 
+ partialGroup := &types.Group{ + ID: "ipv6-partial-group", + AccountID: accountID, + Name: "IPv6Partial", + } + err = manager.Store.CreateGroup(ctx, partialGroup) + require.NoError(t, err) + require.NoError(t, manager.Store.AddPeerToGroup(ctx, accountID, peer1.ID, partialGroup.ID)) + require.NoError(t, manager.Store.AddPeerToGroup(ctx, accountID, peer2.ID, partialGroup.ID)) + + // Switch IPv6EnabledGroups to only the partial group. + updatedSettings, err := manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{partialGroup.ID}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + assert.Equal(t, []string{partialGroup.ID}, updatedSettings.IPv6EnabledGroups) + + // peer1 and peer2 should have IPv6; peer3 should not. + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + peerMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, p := range peers { + peerMap[p.ID] = p + } + assert.True(t, peerMap[peer1.ID].IPv6.IsValid(), "peer1 in partial group should keep IPv6") + assert.True(t, peerMap[peer2.ID].IPv6.IsValid(), "peer2 in partial group should keep IPv6") + assert.False(t, peerMap[peer3.ID].IPv6.IsValid(), "peer3 not in partial group should lose IPv6") + + // Clearing all groups disables IPv6 for everyone. 
+ updatedSettings, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + assert.Empty(t, updatedSettings.IPv6EnabledGroups) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.False(t, p.IPv6.IsValid(), "peer %s should have no IPv6 when groups cleared", p.ID) + } + + // Re-enabling with the partial group should allocate IPv6 only for peer1 and peer2. + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{partialGroup.ID}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + peerMap = make(map[string]*nbpeer.Peer, len(peers)) + for _, p := range peers { + peerMap[p.ID] = p + } + assert.True(t, peerMap[peer1.ID].IPv6.IsValid(), "peer1 should get IPv6 back") + assert.True(t, peerMap[peer2.ID].IPv6.IsValid(), "peer2 should get IPv6 back") + assert.False(t, peerMap[peer3.ID].IPv6.IsValid(), "peer3 still excluded") + + // No-op update with the same groups should not cause errors. + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{partialGroup.ID}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + // Setting a nonexistent group ID should fail. 
+ _, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{"nonexistent-group-id"}, + Extra: &types.ExtraSettings{}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "does not exist") +} + func TestUpdateUserAuthWithSingleMode(t *testing.T) { t.Run("sets defaults and overrides domain from store", func(t *testing.T) { ctrl := gomock.NewController(t) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index ddc3e00c3..2388115ff 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -231,6 +231,10 @@ const ( DomainDeleted Activity = 119 // DomainValidated indicates that a custom domain was validated DomainValidated Activity = 120 + // AccountIPv6Enabled indicates that a user enabled IPv6 overlay for the account + AccountIPv6Enabled Activity = 121 + // AccountIPv6Disabled indicates that a user disabled IPv6 overlay for the account + AccountIPv6Disabled Activity = 122 AccountDeleted Activity = 99999 ) @@ -347,6 +351,9 @@ var activityMap = map[Activity]Code{ AccountAutoUpdateAlwaysEnabled: {"Account auto-update always enabled", "account.setting.auto.update.always.enable"}, AccountAutoUpdateAlwaysDisabled: {"Account auto-update always disabled", "account.setting.auto.update.always.disable"}, + AccountIPv6Enabled: {"Account IPv6 overlay enabled", "account.setting.ipv6.enable"}, + AccountIPv6Disabled: {"Account IPv6 overlay disabled", "account.setting.ipv6.disable"}, + IdentityProviderCreated: {"Identity provider created", "identityprovider.create"}, IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"}, IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"}, diff --git a/management/server/auth/session.go b/management/server/auth/session.go new file mode 100644 index 000000000..7621a1c10 
--- /dev/null +++ b/management/server/auth/session.go @@ -0,0 +1,61 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" +) + +const ( + usedTokenKeyPrefix = "jwt-used:" + usedTokenMarker = "1" +) + +var ( + ErrTokenAlreadyUsed = errors.New("JWT already used") + ErrTokenExpired = errors.New("JWT expired") +) + +type SessionStore struct { + cache *cache.Cache[string] +} + +func NewSessionStore(cacheStore store.StoreInterface) *SessionStore { + return &SessionStore{cache: cache.New[string](cacheStore)} +} + +// RegisterToken records a JWT until its exp time and rejects reuse. +func (s *SessionStore) RegisterToken(ctx context.Context, token string, expiresAt time.Time) error { + ttl := time.Until(expiresAt) + if ttl <= 0 { + return ErrTokenExpired + } + + key := usedTokenKeyPrefix + hashToken(token) + _, err := s.cache.Get(ctx, key) + if err == nil { + return ErrTokenAlreadyUsed + } + + var notFound *store.NotFound + if !errors.As(err, ¬Found) { + return fmt.Errorf("failed to lookup used token entry: %w", err) + } + + if err := s.cache.Set(ctx, key, usedTokenMarker, store.WithExpiration(ttl)); err != nil { + return fmt.Errorf("failed to store used token entry: %w", err) + } + + return nil +} + +func hashToken(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} diff --git a/management/server/auth/session_test.go b/management/server/auth/session_test.go new file mode 100644 index 000000000..3a7d85f4c --- /dev/null +++ b/management/server/auth/session_test.go @@ -0,0 +1,82 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcache "github.com/netbirdio/netbird/management/server/cache" +) + +func newTestSessionStore(t *testing.T) *SessionStore { + t.Helper() + cacheStore, err := 
nbcache.NewStore(context.Background(), time.Hour, time.Hour, 100) + require.NoError(t, err) + return NewSessionStore(cacheStore) +} + +func TestSessionStore_FirstRegisterSucceeds(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + + require.NoError(t, s.RegisterToken(ctx, "token", time.Now().Add(time.Hour))) +} + +func TestSessionStore_RegisterSameTokenTwiceIsRejected(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + exp := time.Now().Add(time.Hour) + + require.NoError(t, s.RegisterToken(ctx, token, exp)) + + err := s.RegisterToken(ctx, token, exp) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTokenAlreadyUsed) +} + +func TestSessionStore_RegisterDifferentTokensAreIndependent(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + exp := time.Now().Add(time.Hour) + + require.NoError(t, s.RegisterToken(ctx, "tokenA", exp)) + require.NoError(t, s.RegisterToken(ctx, "tokenB", exp)) +} + +func TestSessionStore_RegisterWithPastExpiryIsRejected(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + + err := s.RegisterToken(ctx, token, time.Now().Add(-time.Second)) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTokenExpired) +} + +func TestSessionStore_EntryEvictsAtTTLAndAllowsReRegistration(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + + require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond))) + + err := s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond)) + assert.ErrorIs(t, err, ErrTokenAlreadyUsed) + + time.Sleep(120 * time.Millisecond) + + require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(time.Hour))) +} + +func TestHashToken_StableAndDoesNotLeak(t *testing.T) { + a := hashToken("tokenA") + b := hashToken("tokenB") + assert.Equal(t, a, hashToken("tokenA"), "hash must be deterministic") + assert.NotEqual(t, a, b, 
"different tokens must hash differently") + assert.Len(t, a, 64, "sha256 hex must be 64 chars") + assert.NotContains(t, a, "tokenA", "raw token must not appear in hash") +} diff --git a/management/server/dns.go b/management/server/dns.go index baf6debc3..c62fa5185 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -86,7 +86,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 0e37a3b22..c443223c6 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/group.go b/management/server/group.go index 7b5b9b86c..870a441ac 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -117,7 +117,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate}) } return nil @@ -174,6 
+174,10 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -185,7 +189,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -253,7 +257,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate}) } return globalErr @@ -278,37 +282,17 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us var globalErr error groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { - return err - } - - newGroup.AccountID = accountID - - if err = transaction.UpdateGroup(ctx, newGroup); err != nil { - return err - } - - err = transaction.IncrementNetworkSerial(ctx, accountID) - if err != nil { - return err - } - - events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) - eventsToStore = append(eventsToStore, events...) 
- - groupIDs = append(groupIDs, newGroup.ID) - - return nil - }) + events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup) if err != nil { log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) if len(groups) == 1 { return err } globalErr = errors.Join(globalErr, err) - // continue updating other groups + continue } + eventsToStore = append(eventsToStore, events...) + groupIDs = append(groupIDs, newGroup.ID) } updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) @@ -321,12 +305,39 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return globalErr } +func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) { + var events []func() + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + + if err := transaction.UpdateGroup(ctx, newGroup); err != nil { + return err + } + + if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil { + return err + } + + if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + + events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + return nil + }) + return events, err +} + // prepareGroupEvents prepares a list of event functions to be stored. 
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() @@ -458,6 +469,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, groupIDsToDelete); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -486,6 +501,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -493,7 +512,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -531,7 +550,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -552,6 +571,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -559,7 +582,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, 
types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -597,7 +620,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/group_ipv6_test.go b/management/server/group_ipv6_test.go new file mode 100644 index 000000000..e4603c879 --- /dev/null +++ b/management/server/group_ipv6_test.go @@ -0,0 +1,125 @@ +package server + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +// TestGroupIPv6Assignment verifies that peers gain or lose IPv6 addresses +// when they are added to or removed from an IPv6-enabled group. 
+func TestGroupIPv6Assignment(t *testing.T) { + am, _, err := createManager(t) + require.NoError(t, err) + + ctx := context.Background() + userID := groupAdminUserID + + account, err := createAccount(am, "ipv6-grp-test", userID, "ipv6test.example.com") + require.NoError(t, err) + + // Allocate IPv6 subnet for the account + account.Network.NetV6 = types.AllocateIPv6Subnet(rand.New(rand.NewSource(time.Now().UnixNano()))) + require.NoError(t, am.Store.SaveAccount(ctx, account)) + + // Create setup key + setupKey, err := am.CreateSetupKey(ctx, account.Id, "ipv6-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false) + require.NoError(t, err) + + // Create an IPv6-enabled group + ipv6GroupID := "ipv6-enabled-grp" + err = am.CreateGroup(ctx, account.Id, userID, &types.Group{ + ID: ipv6GroupID, + Name: "IPv6 Enabled", + Issued: types.GroupIssuedAPI, + Peers: []string{}, + }) + require.NoError(t, err) + + // Enable IPv6 on that group + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id) + require.NoError(t, err) + settings.IPv6EnabledGroups = []string{ipv6GroupID} + require.NoError(t, am.Store.SaveAccountSettings(ctx, account.Id, settings)) + + // Register a peer (will be in "All" group, not the IPv6 group) + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{ + Key: key.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"}, + }, false) + require.NoError(t, err) + assert.False(t, peer.IPv6.IsValid(), "peer should not have IPv6 before joining an IPv6-enabled group") + + t.Run("GroupAddPeer assigns IPv6", func(t *testing.T) { + err := am.GroupAddPeer(ctx, account.Id, ipv6GroupID, peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.True(t, p.IPv6.IsValid(), "peer should have an IPv6 address after joining 
the group") + }) + + t.Run("GroupDeletePeer clears IPv6", func(t *testing.T) { + err := am.GroupDeletePeer(ctx, account.Id, ipv6GroupID, peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should not have IPv6 after removal from the group") + }) + + t.Run("UpdateGroup with peer addition assigns IPv6", func(t *testing.T) { + grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID) + require.NoError(t, err) + + grp.Peers = append(grp.Peers, peer.ID) + err = am.UpdateGroup(ctx, account.Id, userID, grp) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.True(t, p.IPv6.IsValid(), "peer should have IPv6 after UpdateGroup adds it") + }) + + t.Run("UpdateGroup with peer removal clears IPv6", func(t *testing.T) { + grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID) + require.NoError(t, err) + + grp.Peers = []string{} + err = am.UpdateGroup(ctx, account.Id, userID, grp) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should lose IPv6 after UpdateGroup removes it") + }) + + t.Run("non-IPv6 group changes do not affect IPv6", func(t *testing.T) { + err := am.CreateGroup(ctx, account.Id, userID, &types.Group{ + ID: "regular-grp", + Name: "Regular Group", + Issued: types.GroupIssuedAPI, + Peers: []string{}, + }) + require.NoError(t, err) + + err = am.GroupAddPeer(ctx, account.Id, "regular-grp", peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should not get IPv6 from a non-IPv6 group") + }) +} diff --git 
a/management/server/group_test.go b/management/server/group_test.go index fa818e532..22fda2671 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "errors" "fmt" - "net" "net/netip" "strconv" "sync" @@ -620,7 +619,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -638,7 +637,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -656,7 +655,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -689,7 +688,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -730,7 +729,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -757,7 +756,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -804,7 +803,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -999,10 +998,10 @@ func Test_AddPeerAndAddToAll(t *testing.T) { assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, 
got %d", totalPeers, accountID, len(account.Peers)) } -func uint32ToIP(n uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, n) - return ip +func uint32ToIP(n uint32) netip.Addr { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + return netip.AddrFrom4(b) } func Test_IncrementNetworkSerial(t *testing.T) { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ad36b9d46..b9ea605d3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -5,9 +5,6 @@ import ( "fmt" "net/http" "net/netip" - "os" - "strconv" - "time" "github.com/gorilla/mux" "github.com/rs/cors" @@ -65,15 +62,10 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const ( - apiPrefix = "/api" - rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" - rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" - rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" -) +const apiPrefix = "/api" // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. 
-func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -94,34 +86,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, 
networks return nil, fmt.Errorf("failed to add bypass path: %w", err) } - var rateLimitingConfig *middleware.RateLimiterConfig - if os.Getenv(rateLimitingEnabledKey) == "true" { - rpm := 6 - if v := os.Getenv(rateLimitingRPMKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) - } else { - rpm = value - } - } - - burst := 500 - if v := os.Getenv(rateLimitingBurstKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) - } else { - burst = value - } - } - - rateLimitingConfig = &middleware.RateLimiterConfig{ - RequestsPerMinute: float64(rpm), - Burst: burst, - CleanupInterval: 6 * time.Hour, - LimiterTTL: 24 * time.Hour, - } + if rateLimiter == nil { + log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled") + rateLimiter = middleware.NewAPIRateLimiter(nil) + rateLimiter.SetEnabled(false) } authMiddleware := middleware.NewAuthMiddleware( @@ -129,7 +97,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, - rateLimitingConfig, + rateLimiter, appMetrics.GetMeter(), ) @@ -171,7 +139,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks zonesManager.RegisterEndpoints(router, zManager) recordsManager.RegisterEndpoints(router, rManager) idp.AddEndpoints(accountManager, router) - instance.AddEndpoints(instanceManager, router) + instance.AddEndpoints(instanceManager, accountManager, router) instance.AddVersionEndpoint(instanceManager, router) if serviceManager != nil && reverseProxyDomainManager != nil { reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) diff --git 
a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index cc5567e3d..31820b9fb 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -4,10 +4,13 @@ import ( "context" "encoding/json" "fmt" + "math" "net/http" "net/netip" "time" + log "github.com/sirupsen/logrus" + "github.com/gorilla/mux" goversion "github.com/hashicorp/go-version" @@ -29,7 +32,9 @@ const ( // MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16) MinNetworkBitsIPv4 = 28 // MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges - MinNetworkBitsIPv6 = 120 + MinNetworkBitsIPv6 = 120 + // MaxNetworkSizeIPv6 is the largest allowed IPv6 prefix (smallest number) + MaxNetworkSizeIPv6 = 48 disableAutoUpdate = "disabled" autoUpdateLatestVersion = "latest" ) @@ -76,12 +81,35 @@ func validateMinimumSize(prefix netip.Prefix) error { if addr.Is4() && prefix.Bits() > MinNetworkBitsIPv4 { return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv4", MinNetworkBitsIPv4) } - if addr.Is6() && prefix.Bits() > MinNetworkBitsIPv6 { - return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6) + if addr.Is6() { + if prefix.Bits() > MinNetworkBitsIPv6 { + return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6) + } + if prefix.Bits() < MaxNetworkSizeIPv6 { + return status.Errorf(status.InvalidArgument, "network range too large: maximum size is /%d for IPv6", MaxNetworkSizeIPv6) + } } return nil } +func (h *handler) parseAndValidateNetworkRange(ctx context.Context, accountID, userID, rangeStr string, requireV6 bool) (netip.Prefix, error) { + prefix, err := netip.ParsePrefix(rangeStr) + if err != nil { + return 
netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err) + } + prefix = prefix.Masked() + if requireV6 && !prefix.Addr().Is6() { + return netip.Prefix{}, status.Errorf(status.InvalidArgument, "network range must be an IPv6 address") + } + if !requireV6 && prefix.Addr().Is6() { + return netip.Prefix{}, status.Errorf(status.InvalidArgument, "network range must be an IPv4 address") + } + if err := h.validateNetworkRange(ctx, accountID, userID, prefix); err != nil { + return netip.Prefix{}, err + } + return prefix, nil +} + func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID string, networkRange netip.Prefix) error { if !networkRange.IsValid() { return nil @@ -117,9 +145,12 @@ func (h *handler) validateCapacity(ctx context.Context, accountID, userID string } func calculateMaxHosts(prefix netip.Prefix) int64 { - availableAddresses := prefix.Addr().BitLen() - prefix.Bits() - maxHosts := int64(1) << availableAddresses + hostBits := prefix.Addr().BitLen() - prefix.Bits() + if hostBits >= 63 { + return math.MaxInt64 + } + maxHosts := int64(1) << hostBits if prefix.Addr().Is4() { maxHosts -= 2 // network and broadcast addresses } @@ -164,6 +195,24 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { } resp := toAccountResponse(accountID, settings, meta, onboarding) + + // Populate effective network ranges when settings don't have explicit overrides. 
+ if resp.Settings.NetworkRange == nil || resp.Settings.NetworkRangeV6 == nil { + v4, v6, err := h.settingsManager.GetEffectiveNetworkRanges(r.Context(), accountID) + if err != nil { + log.WithContext(r.Context()).Warnf("get effective network ranges: %v", err) + } else { + if resp.Settings.NetworkRange == nil && v4.IsValid() { + s := v4.String() + resp.Settings.NetworkRange = &s + } + if resp.Settings.NetworkRangeV6 == nil && v6.IsValid() { + s := v6.String() + resp.Settings.NetworkRangeV6 = &s + } + } + } + util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -228,6 +277,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS if req.Settings.AutoUpdateAlways != nil { returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways } + if req.Settings.Ipv6EnabledGroups != nil { + returnSettings.IPv6EnabledGroups = *req.Settings.Ipv6EnabledGroups + } return returnSettings, nil } @@ -262,18 +314,23 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" { - prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange) + prefix, err := h.parseAndValidateNetworkRange(r.Context(), accountID, userID, *req.Settings.NetworkRange, false) if err != nil { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w) - return - } - if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil { util.WriteError(r.Context(), err, w) return } settings.NetworkRange = prefix } + if req.Settings.NetworkRangeV6 != nil && *req.Settings.NetworkRangeV6 != "" { + prefix, err := h.parseAndValidateNetworkRange(r.Context(), accountID, userID, *req.Settings.NetworkRangeV6, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + settings.NetworkRangeV6 = prefix + } + var onboarding *types.AccountOnboarding if req.Onboarding != nil { onboarding = 
&types.AccountOnboarding{ @@ -352,6 +409,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, AutoUpdateAlways: &settings.AutoUpdateAlways, + Ipv6EnabledGroups: &settings.IPv6EnabledGroups, EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, LocalAuthDisabled: &settings.LocalAuthDisabled, } @@ -360,6 +418,10 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A networkRangeStr := settings.NetworkRange.String() apiSettings.NetworkRange = &networkRangeStr } + if settings.NetworkRangeV6.IsValid() { + networkRangeV6Str := settings.NetworkRangeV6.String() + apiSettings.NetworkRangeV6 = &networkRangeV6Str + } apiOnboarding := api.AccountOnboarding{ OnboardingFlowPending: onboarding.OnboardingFlowPending, diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 739dfe2f6..fc1517a30 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -5,8 +5,10 @@ import ( "context" "encoding/json" "io" + "math" "net/http" "net/http/httptest" + "net/netip" "testing" "time" @@ -31,6 +33,10 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { GetSettings(gomock.Any(), account.Id, "test_user"). Return(account.Settings, nil). AnyTimes() + settingsMockManager.EXPECT(). + GetEffectiveNetworkRanges(gomock.Any(), account.Id). + Return(netip.Prefix{}, netip.Prefix{}, nil). 
+ AnyTimes() return &handler{ accountManager: &mock_server.MockAccountManager{ @@ -336,3 +342,27 @@ func TestAccounts_AccountsHandler(t *testing.T) { }) } } + +func TestCalculateMaxHosts(t *testing.T) { + tests := []struct { + name string + prefix string + min int64 + }{ + {"v4 /24", "100.64.0.0/24", 254}, + {"v4 /16", "100.64.0.0/16", 65534}, + {"v4 /28", "100.64.0.0/28", 14}, + {"v6 /64", "fd00::/64", math.MaxInt64}, + {"v6 /120", "fd00::/120", 256}, + {"v6 /112", "fd00::/112", 65536}, + {"v6 /48", "fd00::/48", math.MaxInt64}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefix := netip.MustParsePrefix(tt.prefix) + got := calculateMaxHosts(prefix) + assert.Equal(t, tt.min, got) + }) + } +} diff --git a/management/server/http/handlers/dns/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go index bce1c4b78..dbbdf3ed9 100644 --- a/management/server/http/handlers/dns/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -3,7 +3,10 @@ package dns import ( "encoding/json" "fmt" + "net" "net/http" + "strconv" + "strings" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" @@ -201,7 +204,11 @@ func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.R func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { var nsList []nbdns.NameServer for _, apiNS := range apiNSList { - parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s:%d", apiNS.NsType, apiNS.Ip, apiNS.Port)) + host, err := unwrapBracketedHost(apiNS.Ip) + if err != nil { + return nil, err + } + parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s", apiNS.NsType, net.JoinHostPort(host, strconv.Itoa(apiNS.Port)))) if err != nil { return nil, err } @@ -211,6 +218,18 @@ func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { return nsList, nil } +// unwrapBracketedHost returns ip with surrounding brackets stripped, rejecting +// inputs 
with mismatched brackets. +func unwrapBracketedHost(ip string) (string, error) { + if !strings.ContainsAny(ip, "[]") { + return ip, nil + } + if !strings.HasPrefix(ip, "[") || !strings.HasSuffix(ip, "]") { + return "", fmt.Errorf("malformed bracketed address: %s", ip) + } + return ip[1 : len(ip)-1], nil +} + func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.NameserverGroup { var nsList []api.Nameserver for _, ns := range serverNSGroup.NameServers { diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index 4716782f3..a165f009b 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -233,3 +233,37 @@ func TestNameserversHandlers(t *testing.T) { }) } } + +func TestToServerNSList_IPv6(t *testing.T) { + tests := []struct { + name string + input []api.Nameserver + expectIP netip.Addr + }{ + { + name: "IPv4", + input: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + expectIP: netip.MustParseAddr("1.1.1.1"), + }, + { + name: "IPv6", + input: []api.Nameserver{ + {Ip: "2001:4860:4860::8888", NsType: "udp", Port: 53}, + }, + expectIP: netip.MustParseAddr("2001:4860:4860::8888"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := toServerNSList(tc.input) + assert.NoError(t, err) + if assert.Len(t, result, 1) { + assert.Equal(t, tc.expectIP, result[0].IP) + assert.Equal(t, 53, result[0].Port) + } + }) + } +} diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index c7b4cbcdd..57e238630 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -7,8 +7,8 @@ import ( "errors" "fmt" "io" - "net" "net/http" + "net/netip" "net/http/httptest" 
"strings" "testing" @@ -29,8 +29,8 @@ import ( ) var TestPeers = map[string]*nbpeer.Peer{ - "A": {Key: "A", ID: "peer-A-ID", IP: net.ParseIP("100.100.100.100")}, - "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, + "A": {Key: "A", ID: "peer-A-ID", IP: netip.MustParseAddr("100.100.100.100")}, + "B": {Key: "B", ID: "peer-B-ID", IP: netip.MustParseAddr("200.200.200.200")}, } func initGroupTestData(initGroups ...*types.Group) *handler { diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index cd9fae6b8..e98ce4d7c 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" nbinstance "github.com/netbirdio/netbird/management/server/instance" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -15,13 +16,15 @@ import ( // handler handles the instance setup HTTP endpoints type handler struct { instanceManager nbinstance.Manager + setupManager *nbinstance.SetupService } // AddEndpoints registers the instance setup endpoints. // These endpoints bypass authentication for initial setup. -func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) { +func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, router *mux.Router) { h := &handler{ instanceManager: instanceManager, + setupManager: nbinstance.NewSetupService(instanceManager, accountManager), } router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS") @@ -55,24 +58,36 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { // setup creates the initial admin user for the instance. 
// This endpoint is unauthenticated but only works when setup is required. func (h *handler) setup(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var req api.SetupRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w) return } - userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name) + result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{ + CreatePAT: req.CreatePat != nil && *req.CreatePat, + PATExpireInDays: req.PatExpireIn, + }) if err != nil { - util.WriteError(r.Context(), err, w) + util.WriteError(ctx, err, w) return } - log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email) + log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email) - util.WriteJSONObject(r.Context(), w, api.SetupResponse{ - UserId: userData.ID, - Email: userData.Email, - }) + resp := api.SetupResponse{ + UserId: result.User.ID, + Email: result.User.Email, + } + + if result.PATPlainToken != "" { + resp.PersonalAccessToken = &result.PATPlainToken + } + + w.Header().Set("Cache-Control", "no-store") + util.WriteJSONObject(ctx, w, resp) } // getVersionInfo returns version information for NetBird components. 
diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index 470079c85..711e01964 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -10,12 +10,18 @@ import ( "net/mail" "testing" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/idp" nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/management/server/mock_server" + nbstore "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -25,6 +31,7 @@ type mockInstanceManager struct { isSetupRequired bool isSetupRequiredFn func(ctx context.Context) (bool, error) createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + rollbackSetupFn func(ctx context.Context, userID string) error getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error) } @@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo }, nil } +func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error { + if m.rollbackSetupFn != nil { + return m.rollbackSetupFn(ctx, userID) + } + return nil +} + func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) { if m.getVersionInfoFn != nil { return m.getVersionInfoFn(ctx) @@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V var _ nbinstance.Manager = 
(*mockInstanceManager)(nil) func setupTestRouter(manager nbinstance.Manager) *mux.Router { + return setupTestRouterWithPAT(manager, nil) +} + +func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager) *mux.Router { router := mux.NewRouter() - AddEndpoints(manager, router) + AddEndpoints(manager, accountManager, router) return router } @@ -161,6 +179,7 @@ func TestSetup_Success(t *testing.T) { router.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) var response api.SetupResponse err := json.NewDecoder(rec.Body).Decode(&response) @@ -293,6 +312,239 @@ func TestSetup_ManagerError(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) } +func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "false") + + manager := &mockInstanceManager{isSetupRequired: true} + // NB_SETUP_PAT_ENABLED=false: request fields must be silently ignored + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Nil(t, response.PersonalAccessToken) +} + +func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}` + req := httptest.NewRequest(http.MethodPost, 
"/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Nil(t, response.PersonalAccessToken) +} + +func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + createCalled := false + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "u1", Email: email, Name: name}, nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "u1", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "u1", initiator) + assert.Equal(t, "u1", target) + assert.Equal(t, "setup-token", name) + assert.Equal(t, 1, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + } + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) + assert.True(t, createCalled) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + 
require.NotNil(t, response.PersonalAccessToken) + assert.Equal(t, "nbp_plain", *response.PersonalAccessToken) +} + +func TestSetup_PAT_ExpireOutOfRange(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) +} + +func TestSetup_PAT_Success(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + } + + gotAccountArgs := struct { + userID string + email string + }{} + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + gotAccountArgs.userID = userAuth.UserId + gotAccountArgs.email = userAuth.Email + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiator) + assert.Equal(t, "owner-id", target) + assert.Equal(t, "setup-token", name) + assert.Equal(t, 30, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", 
"create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Equal(t, "owner-id", response.UserId) + require.NotNil(t, response.PersonalAccessToken) + assert.Equal(t, "nbp_plain", *response.PersonalAccessToken) + assert.Equal(t, "owner-id", gotAccountArgs.userID) + assert.Equal(t, "admin@example.com", gotAccountArgs.email) +} + +func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("", status.NewAccountNotFoundError("owner-id")) + + rolledBackFor := "" + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, userID string) error { + rolledBackFor = userID + return nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "", errors.New("db down") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + 
req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id") +} + +func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rolledBackFor := "" + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, userID string) error { + rolledBackFor = userID + return nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, status.Errorf(status.Internal, "token store unavailable") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called 
when CreatePAT fails") +} + func TestGetVersionInfo_Success(t *testing.T) { manager := &mockInstanceManager{} router := mux.NewRouter() diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 6b9a69f04..91026a374 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -220,6 +220,18 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri } } + if req.Ipv6 != nil { + v6Addr, err := parseIPv6(req.Ipv6) + if err != nil { + util.WriteError(ctx, status.Errorf(status.InvalidArgument, "%v", err), w) + return + } + if err = h.accountManager.UpdatePeerIPv6(ctx, accountID, userID, peerID, v6Addr); err != nil { + util.WriteError(ctx, err, w) + return + } + } + peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) if err != nil { util.WriteError(ctx, err, w) @@ -355,6 +367,21 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM } } +func parseIPv6(s *string) (netip.Addr, error) { + if s == nil { + return netip.Addr{}, fmt.Errorf("IPv6 address is nil") + } + addr, err := netip.ParseAddr(*s) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IPv6 address %s: %w", *s, err) + } + addr = addr.Unmap() + if !addr.Is6() { + return netip.Addr{}, fmt.Errorf("address %s is not IPv6", *s) + } + return addr, nil +} + // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. 
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) @@ -417,7 +444,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) + netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } @@ -529,6 +556,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee GeonameId: int(peer.Location.GeoNameID), Id: peer.ID, Ip: peer.IP.String(), + Ipv6: peerIPv6String(peer), LastSeen: peer.Status.LastSeen, Name: peer.Name, Os: peer.Meta.OS, @@ -547,6 +575,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD Id: peer.ID, Name: peer.Name, Ip: peer.IP.String(), + Ipv6: peerIPv6String(peer), ConnectionIp: peer.Location.ConnectionIP.String(), Connected: peer.Status.Connected, LastSeen: peer.Status.LastSeen, @@ -601,6 +630,7 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn Id: peer.ID, Name: peer.Name, Ip: peer.IP.String(), + Ipv6: peerIPv6String(peer), ConnectionIp: peer.Location.ConnectionIP.String(), Connected: peer.Status.Connected, LastSeen: peer.Status.LastSeen, @@ -677,3 +707,11 @@ func fqdnList(extraLabels []string, dnsDomain string) []string { } return fqdnList } + +func peerIPv6String(peer *nbpeer.Peer) *string { + if !peer.IPv6.IsValid() { + return nil + } + s := peer.IPv6.String() + return &s +} diff --git a/management/server/http/handlers/peers/peers_handler_test.go 
b/management/server/http/handlers/peers/peers_handler_test.go index 6b3616597..9db095c8d 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -146,7 +146,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { UpdatePeerIPFunc: func(_ context.Context, accountID, userID, peerID string, newIP netip.Addr) error { for _, peer := range peers { if peer.ID == peerID { - peer.IP = net.IP(newIP.AsSlice()) + peer.IP = newIP return nil } } @@ -228,7 +228,8 @@ func TestGetPeers(t *testing.T) { peer := &nbpeer.Peer{ ID: testPeerID, Key: "key", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "PeerName", LoginExpirationEnabled: false, @@ -368,7 +369,8 @@ func TestGetAccessiblePeers(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", Key: "key1", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00:1234::1"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "peer1", LoginExpirationEnabled: false, @@ -378,7 +380,8 @@ func TestGetAccessiblePeers(t *testing.T) { peer2 := &nbpeer.Peer{ ID: "peer2", Key: "key2", - IP: net.ParseIP("100.64.0.2"), + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00:1234::2"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "peer2", LoginExpirationEnabled: false, @@ -388,7 +391,8 @@ func TestGetAccessiblePeers(t *testing.T) { peer3 := &nbpeer.Peer{ ID: "peer3", Key: "key3", - IP: net.ParseIP("100.64.0.3"), + IP: netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00:1234::3"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "peer3", LoginExpirationEnabled: false, @@ -532,7 +536,8 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { testPeer := &nbpeer.Peer{ ID: testPeerID, Key: "key", - IP: net.ParseIP("100.64.0.1"), + IP: 
netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, Name: "test-host@netbird.io", LoginExpirationEnabled: false, diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index a2d656a47..eea31ebc6 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -7,11 +7,8 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -45,11 +42,6 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa // getAllCountries retrieves a list of all countries func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { - if err := l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - if l.geolocationManager == nil { // TODO: update error message to include geo db self hosted doc link when ready util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) @@ -71,11 +63,6 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req // getCitiesByCountry retrieves a list of cities based on the given country code func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { - if err := 
l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - vars := mux.Vars(r) countryCode := vars["country"] if !countryCodeRegex.MatchString(countryCode) { @@ -102,27 +89,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http. util.WriteJSONObject(r.Context(), w, cities) } -func (l *geolocationsHandler) authenticateUser(r *http.Request) error { - ctx := r.Context() - - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - return err - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return status.NewPermissionValidationError(err) - } - - if !allowed { - return status.NewPermissionDeniedError() - } - return nil -} - func toCountryResponse(country geolocation.Country) api.Country { return api.Country{ CountryName: country.CountryName, diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 63be672e6..6d075d9c2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel/metric" "github.com/netbirdio/management-integrations/integrations" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" @@ -42,14 +43,9 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, - rateLimiterConfig *RateLimiterConfig, + rateLimiter *APIRateLimiter, meter metric.Meter, ) *AuthMiddleware { - var rateLimiter *APIRateLimiter - if rateLimiterConfig != nil { - rateLimiter = NewAPIRateLimiter(rateLimiterConfig) - } - var 
patUsageTracker *PATUsageTracker if meter != nil { var err error @@ -87,17 +83,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, authHeader) - if err != nil { + if err := m.checkJWTFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) case "token": - request, err := m.checkPATFromRequest(r, authHeader) - if err != nil { + if err := m.checkPATFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) // Check if it's a status error, otherwise default to Unauthorized if _, ok := status.FromError(err); !ok { @@ -106,7 +99,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { util.WriteError(r.Context(), err, w) return } - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) default: util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return @@ -115,19 +108,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } ctx := r.Context() userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) if err != nil { - return r, err + return err } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 
{ @@ -143,7 +136,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := m.ensureAccount(ctx, userAuth) if err != nil { - return r, err + return err } if userAuth.AccountId != accountId { @@ -153,7 +146,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) if err != nil { - return r, err + return err } err = m.syncUserJWTGroups(ctx, userAuth) @@ -164,41 +157,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] _, err = m.getUserFromUserAuth(ctx, userAuth) if err != nil { log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) - return r, err + return err } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } if m.patUsageTracker != nil { m.patUsageTracker.IncrementUsage(token) } - if m.rateLimiter != nil && !isTerraformRequest(r) { - if !m.rateLimiter.Allow(token) { - return r, status.Errorf(status.TooManyRequests, "too many requests") - } + if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) { + return status.Errorf(status.TooManyRequests, "too many requests") } ctx := r.Context() user, pat, accDomain, accCategory, err := 
m.authManager.GetPATInfo(ctx, token) if err != nil { - return r, fmt.Errorf("invalid Token: %w", err) + return fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.GetExpirationDate()) { - return r, fmt.Errorf("token expired") + return fmt.Errorf("token expired") } err = m.authManager.MarkPATUsed(ctx, pat.ID) if err != nil { - return r, err + return err } userAuth := auth.UserAuth{ @@ -216,7 +209,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] } } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } func isTerraformRequest(r *http.Request) bool { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index f397c63a4..8f736fbfd 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) @@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + 
NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go index 936b34319..bfd44afee 100644 --- a/management/server/http/middleware/rate_limiter.go +++ 
b/management/server/http/middleware/rate_limiter.go @@ -4,14 +4,27 @@ import ( "context" "net" "net/http" + "os" + "strconv" "sync" + "sync/atomic" "time" + log "github.com/sirupsen/logrus" "golang.org/x/time/rate" "github.com/netbirdio/netbird/shared/management/http/util" ) +const ( + RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED" + RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST" + RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM" + + defaultAPIRPM = 6 + defaultAPIBurst = 500 +) + // RateLimiterConfig holds configuration for the API rate limiter type RateLimiterConfig struct { // RequestsPerMinute defines the rate at which tokens are replenished @@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig { } } +func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) { + rpm := defaultAPIRPM + if v := os.Getenv(RateLimitingRPMEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm) + } else { + rpm = value + } + } + if rpm <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM) + rpm = defaultAPIRPM + } + + burst := defaultAPIBurst + if v := os.Getenv(RateLimitingBurstEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst) + } else { + burst = value + } + } + if burst <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst) + burst = defaultAPIBurst + } + + return &RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + }, os.Getenv(RateLimitingEnabledEnv) == "true" +} + // limiterEntry holds a rate limiter and its last access time type limiterEntry struct { limiter *rate.Limiter @@ -46,6 +96,7 @@ type APIRateLimiter struct { limiters map[string]*limiterEntry mu 
sync.RWMutex stopChan chan struct{} + enabled atomic.Bool } // NewAPIRateLimiter creates a new API rate limiter with the given configuration @@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { limiters: make(map[string]*limiterEntry), stopChan: make(chan struct{}), } + rl.enabled.Store(true) go rl.cleanupLoop() return rl } +func (rl *APIRateLimiter) SetEnabled(enabled bool) { + rl.enabled.Store(enabled) +} + +func (rl *APIRateLimiter) Enabled() bool { + return rl.enabled.Load() +} + +func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) { + if config == nil { + return + } + if config.RequestsPerMinute <= 0 || config.Burst <= 0 { + log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst) + return + } + + newRPS := rate.Limit(config.RequestsPerMinute / 60.0) + newBurst := config.Burst + + rl.mu.Lock() + rl.config.RequestsPerMinute = config.RequestsPerMinute + rl.config.Burst = newBurst + snapshot := make([]*rate.Limiter, 0, len(rl.limiters)) + for _, entry := range rl.limiters { + snapshot = append(snapshot, entry.limiter) + } + rl.mu.Unlock() + + for _, l := range snapshot { + l.SetLimit(newRPS) + l.SetBurst(newBurst) + } +} + // Allow checks if a request for the given key (token) is allowed func (rl *APIRateLimiter) Allow(key string) bool { + if !rl.enabled.Load() { + return true + } limiter := rl.getLimiter(key) return limiter.Allow() } @@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool { // Wait blocks until the rate limiter allows another request for the given key // Returns an error if the context is canceled func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + if !rl.enabled.Load() { + return nil + } limiter := rl.getLimiter(key) return limiter.Wait(ctx) } @@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) { // Returns 429 Too Many Requests if the rate limit is exceeded. 
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !rl.enabled.Load() { + next.ServeHTTP(w, r) + return + } clientIP := getClientIP(r) if !rl.Allow(clientIP) { util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) diff --git a/management/server/http/middleware/rate_limiter_test.go b/management/server/http/middleware/rate_limiter_test.go index 68f804e57..4b97d1874 100644 --- a/management/server/http/middleware/rate_limiter_test.go +++ b/management/server/http/middleware/rate_limiter_test.go @@ -1,8 +1,10 @@ package middleware import ( + "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) { // Should be allowed again assert.True(t, rl.Allow("test-key")) } + +func TestAPIRateLimiter_SetEnabled(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("key")) + assert.False(t, rl.Allow("key"), "burst exhausted while enabled") + + rl.SetEnabled(false) + assert.False(t, rl.Enabled()) + for i := 0; i < 5; i++ { + assert.True(t, rl.Allow("key"), "disabled limiter must always allow") + } + + rl.SetEnabled(true) + assert.True(t, rl.Enabled()) + assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state") +} + +func TestAPIRateLimiter_UpdateConfig(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k1")) + assert.True(t, rl.Allow("k1")) + assert.False(t, rl.Allow("k1"), "burst=2 exhausted") + + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + + 
// New burst applies to existing keys in place; bucket refills up to new burst over time, + // but importantly newly-added keys use the updated config immediately. + assert.True(t, rl.Allow("k2")) + for i := 0; i < 9; i++ { + assert.True(t, rl.Allow("k2")) + } + assert.False(t, rl.Allow("k2"), "new burst=10 exhausted") +} + +func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + rl.UpdateConfig(nil) // must not panic or zero the config + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) +} + +func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) + + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + + rl.Reset("k") + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored") +} + +func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 600, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + var wg sync.WaitGroup + stop := make(chan struct{}) + + for i := 0; i < 8; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + key := fmt.Sprintf("k%d", id) + for { + select { + case <-stop: + return + default: + 
rl.Allow(key) + } + } + }(i) + } + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 200; i++ { + select { + case <-stop: + return + default: + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: float64(30 + (i % 90)), + Burst: 1 + (i % 20), + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + rl.SetEnabled(i%2 == 0) + } + } + }() + + time.Sleep(100 * time.Millisecond) + close(stop) + wg.Wait() +} + +func TestRateLimiterConfigFromEnv(t *testing.T) { + t.Setenv(RateLimitingEnabledEnv, "true") + t.Setenv(RateLimitingRPMEnv, "42") + t.Setenv(RateLimitingBurstEnv, "7") + + cfg, enabled := RateLimiterConfigFromEnv() + assert.True(t, enabled) + assert.Equal(t, float64(42), cfg.RequestsPerMinute) + assert.Equal(t, 7, cfg.Burst) + + t.Setenv(RateLimitingEnabledEnv, "false") + _, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + + t.Setenv(RateLimitingEnabledEnv, "") + t.Setenv(RateLimitingRPMEnv, "") + t.Setenv(RateLimitingBurstEnv, "") + cfg, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute) + assert.Equal(t, defaultAPIBurst, cfg.Burst) + + t.Setenv(RateLimitingRPMEnv, "0") + t.Setenv(RateLimitingBurstEnv, "-5") + cfg, _ = RateLimiterConfigFromEnv() + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default") + assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default") +} diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 0203d6177..1a8b83c7e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") 
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/http/testing/testing_tools/tools.go 
b/management/server/http/testing/testing_tools/tools.go index b7a63b104..9a78620c9 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -5,9 +5,9 @@ import ( "context" "fmt" "io" - "net" "net/http" "net/http/httptest" + "net/netip" "os" "strconv" "testing" @@ -133,7 +133,7 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se ID: fmt.Sprintf("oldpeer-%d", i), DNSLabel: fmt.Sprintf("oldpeer-%d", i), Key: peerKey.PublicKey().String(), - IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + IP: netip.MustParseAddr(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: TestUserId, } diff --git a/management/server/identity_provider.go b/management/server/identity_provider.go index 8fd96c238..f965f36b8 100644 --- a/management/server/identity_provider.go +++ b/management/server/identity_provider.go @@ -274,7 +274,7 @@ func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.C } // generateIdentityProviderID generates a unique ID for an identity provider. -// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft), +// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft, adfs), // the ID is prefixed with the type name. Generic OIDC providers get no prefix. 
func generateIdentityProviderID(idpType types.IdentityProviderType) string { id := xid.New().String() @@ -296,6 +296,8 @@ func generateIdentityProviderID(idpType types.IdentityProviderType) string { return "authentik-" + id case types.IdentityProviderTypeKeycloak: return "keycloak-" + id + case types.IdentityProviderTypeADFS: + return "adfs-" + id default: // Generic OIDC - no prefix return id diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 9579d7a35..2c355bb3b 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/dexidp/dex/storage" goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" @@ -60,6 +61,13 @@ type Manager interface { // This should only be called when IsSetupRequired returns true. CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) + // RollbackSetup reverses a successful CreateOwnerUser by deleting the user + // from the embedded IDP and reloading setupRequired from persistent state, so + // /api/setup can be retried only when no accounts or local users remain. Used + // when post-user steps (account or PAT creation) fail and the caller wants a + // clean slate. + RollbackSetup(ctx context.Context, userID string) error + // GetVersionInfo returns version information for NetBird components. 
GetVersionInfo(ctx context.Context) (*VersionInfo, error) } @@ -70,6 +78,7 @@ type instanceStore interface { type embeddedIdP interface { CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) + DeleteUser(ctx context.Context, userID string) error GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error) } @@ -187,6 +196,51 @@ func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, n return userData, nil } +// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the +// embedded IDP and reloads setupRequired from persistent state. +func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error { + if m.embeddedIdpManager == nil { + return errors.New("embedded IDP is not enabled") + } + + var deleteErr error + if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil { + if isNotFoundError(err) { + log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID) + } else { + deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err) + } + } + + if err := m.loadSetupRequired(ctx); err != nil { + reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err) + if deleteErr != nil { + return errors.Join(deleteErr, reloadErr) + } + return reloadErr + } + + if deleteErr != nil { + return deleteErr + } + + log.WithContext(ctx).Infof("rolled back setup for user %s", userID) + return nil +} + +func isNotFoundError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, storage.ErrNotFound) { + return true + } + if s, ok := status.FromError(err); ok { + return s.Type() == status.NotFound + } + return false +} + func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error { numAccounts, err := m.store.GetAccountsCounter(ctx) if err != nil { diff --git a/management/server/instance/manager_test.go b/management/server/instance/manager_test.go index e3be9cfea..5ffb016de 100644 
--- a/management/server/instance/manager_test.go +++ b/management/server/instance/manager_test.go @@ -10,16 +10,19 @@ import ( "testing" "time" + "github.com/dexidp/dex/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/shared/management/status" ) type mockIdP struct { - mu sync.Mutex - createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) - users map[string][]*idp.UserData + mu sync.Mutex + createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) + deleteUserFunc func(ctx context.Context, userID string) error + users map[string][]*idp.UserData getAllAccountsErr error } @@ -30,6 +33,13 @@ func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, n return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil } +func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error { + if m.deleteUserFunc != nil { + return m.deleteUserFunc(ctx, userID) + } + return nil +} + func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) { m.mu.Lock() defer m.mu.Unlock() @@ -223,6 +233,77 @@ func TestIsSetupRequired_ReturnsFlag(t *testing.T) { assert.False(t, required) } +func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "management status not found", + err: status.NewUserNotFoundError("owner-id"), + }, + { + name: "dex storage not found", + err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, userID string) error { + assert.Equal(t, "owner-id", userID) + return tt.err + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + mgr.setupRequired = false + + err := 
mgr.RollbackSetup(context.Background(), "owner-id") + require.NoError(t, err) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required, "setup should be required when no accounts or local users remain") + }) + } +} + +func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, _ string) error { + return status.NewUserNotFoundError("owner-id") + }, + } + mgr := newTestManager(idpMock, &mockStore{accountsCount: 1}) + mgr.setupRequired = true + + err := mgr.RollbackSetup(context.Background(), "owner-id") + require.NoError(t, err) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required, "setup should not be required while an account still exists") +} + +func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, _ string) error { + return errors.New("idp unavailable") + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + mgr.setupRequired = false + + err := mgr.RollbackSetup(context.Background(), "owner-id") + require.Error(t, err) + assert.Contains(t, err.Error(), "idp unavailable") + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required, "setup state should be reloaded even when user deletion fails") +} + func TestDefaultManager_ValidateSetupRequest(t *testing.T) { manager := &DefaultManager{setupRequired: true} diff --git a/management/server/instance/setup_service.go b/management/server/instance/setup_service.go new file mode 100644 index 000000000..92a4923be --- /dev/null +++ b/management/server/instance/setup_service.go @@ -0,0 +1,216 @@ +package instance + +import ( + "context" + "fmt" + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/account" + 
"github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + setupPATTokenName = "setup-token" + + // SetupPATEnabledEnvKey enables setup-time Personal Access Token creation. + SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED" + + setupPATDefaultExpireDays = 1 +) + +// SetupOptions controls optional work performed during initial instance setup. +type SetupOptions struct { + // CreatePAT requests creation of a setup Personal Access Token. It is honored + // only when SetupPATEnabledEnvKey is set to "true". + CreatePAT bool + // PATExpireInDays defaults to 1 day when CreatePAT is requested and setup PAT + // creation is enabled. + PATExpireInDays *int +} + +// SetupResult contains resources created during initial instance setup. +type SetupResult struct { + User *idp.UserData + PATPlainToken string +} + +// SetupService orchestrates the initial setup use case across the instance and +// account bounded contexts and owns the compensation logic when a later step +// fails. +type SetupService struct { + instanceManager Manager + accountManager account.Manager + setupPATEnabled bool +} + +// NewSetupService creates a setup use-case service. 
+func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService { + return &SetupService{ + instanceManager: instanceManager, + accountManager: accountManager, + setupPATEnabled: os.Getenv(SetupPATEnabledEnvKey) == "true", + } +} + +func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOptions, error) { + if !opts.CreatePAT { + return opts, nil + } + + if !setupPATEnabled { + opts.CreatePAT = false + opts.PATExpireInDays = nil + return opts, nil + } + + if opts.PATExpireInDays == nil { + defaultExpireInDays := setupPATDefaultExpireDays + opts.PATExpireInDays = &defaultExpireInDays + } + + if *opts.PATExpireInDays < account.PATMinExpireDays || *opts.PATExpireInDays > account.PATMaxExpireDays { + return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays) + } + + return opts, nil +} + +// SetupOwner creates the initial owner user and, when requested and enabled by +// SetupPATEnabledEnvKey, provisions the account and a setup Personal Access +// Token. If account or PAT provisioning fails, created resources are rolled +// back so setup can be retried. If account rollback fails, user rollback is +// skipped to avoid leaving an account without its owner user. 
+func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) { + opts, err := normalizeSetupOptions(opts, m.setupPATEnabled) + if err != nil { + return nil, err + } + + if opts.CreatePAT && m.accountManager == nil { + return nil, fmt.Errorf("account manager is required to create setup PAT") + } + + userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name) + if err != nil { + return nil, err + } + + result := &SetupResult{User: userData} + if !opts.CreatePAT { + return result, nil + } + + userAuth := auth.UserAuth{ + UserId: userData.ID, + Email: userData.Email, + Name: userData.Name, + } + + accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth) + if err != nil { + err = fmt.Errorf("create account for setup user: %w", err) + if rollbackErr := m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, ""); rollbackErr != nil { + return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr) + } + return nil, err + } + + pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, *opts.PATExpireInDays) + if err != nil { + err = fmt.Errorf("create setup PAT: %w", err) + if rollbackErr := m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID); rollbackErr != nil { + return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr) + } + return nil, err + } + + result.PATPlainToken = pat.PlainToken + return result, nil +} + +func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) error { + if accountID == "" { + resolvedAccountID, err := m.lookupSetupAccountIDForRollback(ctx, userID) + if err != nil { + rollbackErr := fmt.Errorf("resolve setup account for rollback: %w", err) + log.WithContext(ctx).Errorf("failed to resolve setup account for user %s after %s: original error: %v, rollback 
error: %v", userID, reason, origErr, rollbackErr) + return rollbackErr + } + accountID = resolvedAccountID + } + + if accountID != "" { + if err := m.rollbackSetupAccount(ctx, accountID); err != nil { + rollbackErr := fmt.Errorf("roll back setup account %s: %w", accountID, err) + log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, rollbackErr) + return rollbackErr + } + log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr) + } + + if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil { + rollbackErr := fmt.Errorf("roll back setup user %s: %w", userID, err) + log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr) + return rollbackErr + } + log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr) + return nil +} + +func (m *SetupService) lookupSetupAccountIDForRollback(ctx context.Context, userID string) (string, error) { + if m.accountManager == nil { + return "", fmt.Errorf("account manager is required to resolve setup account") + } + + accountStore := m.accountManager.GetStore() + if accountStore == nil { + return "", fmt.Errorf("account store is unavailable") + } + + accountID, err := accountStore.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID) + if err != nil { + if isNotFoundError(err) { + return "", nil + } + return "", fmt.Errorf("get setup account ID for rollback: %w", err) + } + + return accountID, nil +} + +// rollbackSetupAccount removes only the setup-created account data from the +// store. It intentionally avoids accountManager.DeleteAccount because the normal +// account deletion path also deletes users from the IdP; embedded IdP cleanup is +// owned by instanceManager.RollbackSetup. 
+func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error { + if m.accountManager == nil { + return fmt.Errorf("account manager is required to roll back setup account") + } + + accountStore := m.accountManager.GetStore() + if accountStore == nil { + return fmt.Errorf("account store is unavailable") + } + + account, err := accountStore.GetAccount(ctx, accountID) + if err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("get setup account for rollback: %w", err) + } + + if err := accountStore.DeleteAccount(ctx, account); err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("delete setup account for rollback: %w", err) + } + + return nil +} diff --git a/management/server/instance/setup_service_test.go b/management/server/instance/setup_service_test.go new file mode 100644 index 000000000..12ec7d0fa --- /dev/null +++ b/management/server/instance/setup_service_test.go @@ -0,0 +1,318 @@ +package instance + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/mock_server" + nbstore "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/status" +) + +type setupInstanceManagerMock struct { + createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + rollbackSetupFn func(ctx context.Context, userID string) error +} + +func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) { + return true, nil +} + +func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.createOwnerUserFn != nil { + return 
m.createOwnerUserFn(ctx, email, password, name) + } + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil +} + +func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error { + if m.rollbackSetupFn != nil { + return m.rollbackSetupFn(ctx, userID) + } + return nil +} + +func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) { + return &VersionInfo{}, nil +} + +var _ Manager = (*setupInstanceManagerMock)(nil) + +func intPtr(v int) *int { + return &v +} + +func TestSetupOwner_PATFeatureDisabled_IgnoresCreatePAT(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "false") + + createCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalls++ + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + }, + &mock_server.MockAccountManager{}, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "owner-id", result.User.ID) + assert.Empty(t, result.PATPlainToken) + assert.Equal(t, 1, createCalls) +} + +func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + createCalled := false + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, 
targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiatorUserID) + assert.Equal(t, "owner-id", targetUserID) + assert.Equal(t, setupPATTokenName, tokenName) + assert.Equal(t, setupPATDefaultExpireDays, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, createCalled) + assert.Equal(t, "nbp_plain", result.PATPlainToken) +} + +func TestSetupOwner_PATFeatureEnabled_MissingAccountManagerFailsBeforeCreateUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + createCalled := false + rollbackCalled := false + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, _ string) error { + rollbackCalled = true + return nil + }, + }, + nil, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "account manager is required") + assert.False(t, createCalled) + assert.False(t, rollbackCalled) +} + +func TestSetupOwner_AccountProvisioningFails_RollsBackSideEffectAccountAndUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("acc-1", 
nil) + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rolledBackFor := "" + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + rolledBackFor = userID + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "", errors.New("metadata update failed") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create account for setup user") + assert.Equal(t, "owner-id", rolledBackFor) + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + assert.Equal(t, "owner-id", userID) + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, 
targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiatorUserID) + assert.Equal(t, "owner-id", targetUserID) + assert.Equal(t, setupPATTokenName, tokenName) + assert.Equal(t, 30, expiresIn) + return nil, status.Errorf(status.Internal, "token store unavailable") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1")) + + rolledBackFor := "" + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + rolledBackFor = userID + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, errors.New("token failure") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, 
err.Error(), "create setup PAT") + assert.Equal(t, "owner-id", rolledBackFor) + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_AccountRollbackFailureStopsBeforeUserRollback(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed")) + + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, errors.New("token failure") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Contains(t, err.Error(), "failed to roll back setup resources") + assert.Equal(t, 0, rollbackCalls) +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 4e6eb0a33..1b77ea335 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) { } // expired peers come separately. 
- if len(networkMap.GetOfflinePeers()) != 1 { - t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer") + if len(networkMap.GetOfflinePeers()) != 2 { + t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer") } expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=" @@ -391,7 +391,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 3ac28cd4a..f1d49193c 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -256,6 +256,7 @@ func startServer( server.MockIntegratedValidator{}, networkMapController, nil, + nil, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ff369355e..08091d4b7 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -63,6 +63,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error + UpdatePeerIPv6Func func(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error CreateRouteFunc func(ctx context.Context, 
accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error @@ -128,8 +129,8 @@ type MockAccountManager struct { GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) AllowSyncFunc func(string, uint64) bool - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason) RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) @@ -200,15 +201,15 @@ func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userI return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented") } -func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { +func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { if am.UpdateAccountPeersFunc != nil { - am.UpdateAccountPeersFunc(ctx, accountID) + am.UpdateAccountPeersFunc(ctx, accountID, reason) } } -func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, 
accountID string, reason types.UpdateReason) { if am.BufferUpdateAccountPeersFunc != nil { - am.BufferUpdateAccountPeersFunc(ctx, accountID) + am.BufferUpdateAccountPeersFunc(ctx, accountID, reason) } } @@ -539,6 +540,13 @@ func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userI return status.Errorf(codes.Unimplemented, "method UpdatePeerIP is not implemented") } +func (am *MockAccountManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error { + if am.UpdatePeerIPv6Func != nil { + return am.UpdatePeerIPv6Func(ctx, accountID, userID, peerID, newIPv6) + } + return status.Errorf(codes.Unimplemented, "method UpdatePeerIPv6 is not implemented") +} + // CreateRoute mock implementation of CreateRoute from server.AccountManager interface func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) { if am.CreateRouteFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 3d8c78912..5859bfb0d 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -82,7 +82,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate}) } return newNSGroup.Copy(), nil @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, 
nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -176,7 +176,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete}) } return nil diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d10d4464f..b2c8300d6 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..c96b60bb2 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + serverTypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -177,7 +178,7 @@ 
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete}) return nil } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 86f9b6579..5a0e26533 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -162,7 +162,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } - go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate}) return resource, nil } @@ -270,7 +270,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc } }() - go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate}) return resource, nil } @@ -352,7 +352,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete}) return nil } diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..c7c3f2ff4 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -15,6 +15,7 @@ import ( 
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + serverTypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -119,7 +120,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) - go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate}) return router, nil } @@ -183,7 +184,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) - go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate}) return router, nil } @@ -217,7 +218,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete}) return nil } diff --git a/management/server/peer.go b/management/server/peer.go index a02e34e0d..8a39fbbb8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -6,6 +6,7 @@ import ( b64 "encoding/base64" "fmt" "net" + "net/netip" "slices" "strings" "time" @@ -33,8 +34,8 @@ import ( const remoteJobsMinVer = 
"0.64.0" -// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if -// the current user is not an admin. +// GetPeers returns peers visible to the user within an account. +// Users with "peers:read" see all peers. Otherwise, users see only their own peers, or none if restricted by account settings. func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -46,14 +47,8 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, status.NewPermissionValidationError(err) } - accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter) - if err != nil { - return nil, err - } - - // @note if the user has permission to read peers it shows all account peers if allowed { - return accountPeers, nil + return am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter) } settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) @@ -65,41 +60,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return []*nbpeer.Peer{}, nil } - // @note if it does not have permission read peers then only display it's own peers - peers := make([]*nbpeer.Peer, 0) - peersMap := make(map[string]*nbpeer.Peer) - - for _, peer := range accountPeers { - if user.Id != peer.UserID { - continue - } - peers = append(peers, peer) - peersMap[peer.ID] = peer - } - - return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers) -} - -func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, accountID string, peersMap map[string]*nbpeer.Peer, peers []*nbpeer.Peer) ([]*nbpeer.Peer, error) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - 
return nil, err - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - - // fetch all the peers that have access to the user's peers - for _, peer := range peers { - aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers()) - for _, p := range aclPeers { - peersMap[p.ID] = p - } - } - - return maps.Values(peersMap), nil + return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) } // MarkPeerConnected marks peer as connected (true) or disconnected (false) @@ -561,6 +522,27 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri return account.Network.Copy(), err } +// peerWillHaveIPv6 checks whether the peer's future group memberships +// (auto-groups + allGroupID) overlap with IPv6EnabledGroups. +func peerWillHaveIPv6(settings *types.Settings, groupsToAdd []string, allGroupID string) bool { + enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups)) + for _, gid := range settings.IPv6EnabledGroups { + enabledSet[gid] = struct{}{} + } + + if allGroupID != "" { + if _, ok := enabledSet[allGroupID]; ok { + return true + } + } + for _, gid := range groupsToAdd { + if _, ok := enabledSet[gid]; ok { + return true + } + } + return false +} + type peerAddAuthConfig struct { AccountID string SetupKeyID string @@ -755,8 +737,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe maxAttempts := 10 for attempt := 1; attempt <= maxAttempts; attempt++ { - var freeIP net.IP - freeIP, err = types.AllocateRandomPeerIP(network.Net) + netPrefix, err := netip.ParsePrefix(network.Net.String()) + if err != nil { + return nil, nil, nil, fmt.Errorf("parse network prefix: %w", err) + } + freeIP, err := types.AllocateRandomPeerIP(netPrefix) if err != nil { return nil, nil, nil, 
fmt.Errorf("failed to get free IP: %w", err) } @@ -776,6 +761,29 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe newPeer.DNSLabel = freeLabel newPeer.IP = freeIP + if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil { + var allGroupID string + if !peer.ProxyMeta.Embedded { + allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All") + if err != nil { + log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) + } else { + allGroupID = allGroup.ID + } + } + if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) { + v6Prefix, err := netip.ParsePrefix(network.NetV6.String()) + if err != nil { + return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err) + } + freeIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) + if err != nil { + return nil, nil, nil, fmt.Errorf("allocate peer IPv6: %w", err) + } + newPeer.IPv6 = freeIPv6 + } + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = transaction.AddPeerToAccount(ctx, newPeer) if err != nil { @@ -845,10 +853,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) } - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } @@ -858,6 +862,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe if !addedByUser { opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName } + if newPeer.Status != nil && newPeer.Status.RequiresApproval { + opEvent.Meta["pending_approval"] = true + } if !temporary { am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) @@ -871,21 +878,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, 
accountID, setupKe return p, nmap, pc, err } -func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { - ip = ip.To4() +func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) { + if !ip.Is4() { + return "", fmt.Errorf("DNS label generation requires an IPv4 address, got %s", ip) + } + b := ip.As4() dnsName, err := nbdns.GetParsedDomainLabel(peerHostName) if err != nil { return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err) } - return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil + return fmt.Sprintf("%s-%d-%d", dnsName, b[2], b[3]), nil } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { var peer *nbpeer.Peer - var updated, versionChanged bool + var updated, versionChanged, ipv6CapabilityChanged bool var err error var postureChecks []*posture.Checks var peerGroupIDs []string @@ -921,7 +931,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return err } + oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta) + ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) @@ -945,7 +957,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil, nil, nil, 0, err } - if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { + if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) { err = am.networkMapController.OnPeersUpdated(ctx, 
accountID, []string{peer.ID}) if err != nil { return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) @@ -995,6 +1007,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer var peer *nbpeer.Peer var updateRemotePeers bool var isPeerUpdated bool + var ipv6CapabilityChanged bool var postureChecks []*posture.Checks var peerGroupIDs []string @@ -1034,7 +1047,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return err } + oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta) + ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) if isPeerUpdated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true @@ -1072,7 +1087,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } - if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) { err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) if err != nil { return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err) @@ -1230,7 +1245,8 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se return false } -// GetPeer for a given accountID, peerID and userID error if not found. +// GetPeer returns a peer visible to the user within an account. +// Users with "peers:read" permission can access any peer. Otherwise, users can access only their own peer. 
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { @@ -1255,47 +1271,17 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return peer, nil } - return am.checkIfUserOwnsPeer(ctx, accountID, userID, peer) -} - -func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, err - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - - // it is also possible that user doesn't own the peer but some of his peers have access to it, - // this is a valid case, show the peer as well. - userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) - if err != nil { - return nil, err - } - - for _, p := range userPeers { - aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers()) - for _, aclPeer := range aclPeers { - if aclPeer.ID == peer.ID { - return peer, nil - } - } - } - return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peer.ID, accountID) } // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. 
-func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) +func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { + _ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason) } -func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) +func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { + _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason) } // UpdateAccountPeer updates a single peer that belongs to an account. @@ -1405,6 +1391,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID var peers []*nbpeer.Peer for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired { + continue + } + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) if expired { peers = append(peers, peer) diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index db392ddda..17df761a1 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -11,6 +11,12 @@ import ( "github.com/netbirdio/netbird/shared/management/http/api" ) +// Peer capability constants mirror the proto enum values. +const ( + PeerCapabilitySourcePrefixes int32 = 1 + PeerCapabilityIPv6Overlay int32 = 2 +) + // Peer represents a machine connected to the network. 
// The Peer is a WireGuard peer identified by a public key type Peer struct { @@ -21,7 +27,9 @@ type Peer struct { // WireGuard public key Key string // uniqueness index (check migrations) // IP address of the Peer - IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations) + IP netip.Addr `gorm:"serializer:json"` // uniqueness index per accountID (check migrations) + // IPv6 overlay address of the Peer, zero value if IPv6 is not enabled for the account. + IPv6 netip.Addr `gorm:"serializer:json"` // Meta is a Peer system meta data Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // ProxyMeta is metadata related to proxy peers @@ -115,6 +123,7 @@ type Flags struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool LazyConnectionEnabled bool } @@ -138,6 +147,7 @@ type PeerSystemMeta struct { //nolint:revive Environment Environment `gorm:"serializer:json"` Flags Flags `gorm:"serializer:json"` Files []File `gorm:"serializer:json"` + Capabilities []int32 `gorm:"serializer:json"` } func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { @@ -182,7 +192,8 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { p.SystemManufacturer == other.SystemManufacturer && p.Environment.Cloud == other.Environment.Cloud && p.Environment.Platform == other.Environment.Platform && - p.Flags.isEqual(other.Flags) + p.Flags.isEqual(other.Flags) && + capabilitiesEqual(p.Capabilities, other.Capabilities) } func (p PeerSystemMeta) isEmpty() bool { @@ -210,6 +221,37 @@ func (p *Peer) AddedWithSSOLogin() bool { return p.UserID != "" } +// HasCapability reports whether the peer has the given capability. +func (p *Peer) HasCapability(capability int32) bool { + return slices.Contains(p.Meta.Capabilities, capability) +} + +// SupportsIPv6 reports whether the peer supports IPv6 overlay. 
+func (p *Peer) SupportsIPv6() bool { + return !p.Meta.Flags.DisableIPv6 && p.HasCapability(PeerCapabilityIPv6Overlay) +} + +// SupportsSourcePrefixes reports whether the peer reads SourcePrefixes. +func (p *Peer) SupportsSourcePrefixes() bool { + return p.HasCapability(PeerCapabilitySourcePrefixes) +} + +func capabilitiesEqual(a, b []int32) bool { + if len(a) != len(b) { + return false + } + set := make(map[int32]struct{}, len(a)) + for _, c := range a { + set[c] = struct{}{} + } + for _, c := range b { + if _, ok := set[c]; !ok { + return false + } + } + return true +} + // Copy copies Peer object func (p *Peer) Copy() *Peer { peerStatus := p.Status @@ -221,6 +263,7 @@ func (p *Peer) Copy() *Peer { AccountID: p.AccountID, Key: p.Key, IP: p.IP, + IPv6: p.IPv6, Meta: p.Meta, Name: p.Name, DNSLabel: p.DNSLabel, @@ -323,9 +366,13 @@ func (p *Peer) FQDN(dnsDomain string) string { // EventMeta returns activity event meta related to the peer func (p *Peer) EventMeta(dnsDomain string) map[string]any { - return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt, + meta := map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt, "location_city_name": p.Location.CityName, "location_country_code": p.Location.CountryCode, "location_geo_name_id": p.Location.GeoNameID, "location_connection_ip": p.Location.ConnectionIP} + if p.IPv6.IsValid() { + meta["ipv6"] = p.IPv6.String() + } + return meta } // Copy PeerStatus @@ -369,5 +416,6 @@ func (f Flags) isEqual(other Flags) bool { f.DisableFirewall == other.DisableFirewall && f.BlockLANAccess == other.BlockLANAccess && f.BlockInbound == other.BlockInbound && - f.LazyConnectionEnabled == other.LazyConnectionEnabled + f.LazyConnectionEnabled == other.LazyConnectionEnabled && + f.DisableIPv6 == other.DisableIPv6 } diff --git a/management/server/peer/peer_test.go b/management/server/peer/peer_test.go index 1aa3f6ffc..c5b512069 100644 --- 
a/management/server/peer/peer_test.go +++ b/management/server/peer/peer_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -141,3 +142,25 @@ func TestFlags_IsEqual(t *testing.T) { }) } } + +func TestPeerCapabilities(t *testing.T) { + tests := []struct { + name string + capabilities []int32 + ipv6 bool + srcPrefixes bool + }{ + {"no capabilities", nil, false, false}, + {"only source prefixes", []int32{PeerCapabilitySourcePrefixes}, false, true}, + {"only ipv6", []int32{PeerCapabilityIPv6Overlay}, true, false}, + {"both", []int32{PeerCapabilitySourcePrefixes, PeerCapabilityIPv6Overlay}, true, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Peer{Meta: PeerSystemMeta{Capabilities: tt.capabilities}} + assert.Equal(t, tt.ipv6, p.SupportsIPv6()) + assert.Equal(t, tt.srcPrefixes, p.SupportsSourcePrefixes()) + }) + } +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6f8d924fd..07acf865f 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -179,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { testGetNetworkMapGeneral(t) } -func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testGetNetworkMapGeneral(t) -} - func testGetNetworkMapGeneral(t *testing.T) { manager, _, err := createManager(t) if err != nil { @@ -564,25 +559,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } assert.NotNil(t, peer) - // the user can see peer2 because peer1 of the user has access to peer2 due to the All group and the default rule 0 all-to-all access - peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) - if err != nil { - t.Fatal(err) - return - } - assert.NotNil(t, peer) - - // delete the all-to-all policy so that user's peer1 has no access to peer2 - for _, policy := range 
account.Policies { - err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) - if err != nil { - t.Fatal(err) - return - } - } - - // at this point the user can't see the details of peer2 - peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) //nolint + // the user can NOT see peer2 because it is not owned by them. + // Regular users only see peers they directly own. + _, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) assert.Error(t, err) // admin users can always access all the peers @@ -775,7 +754,8 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou ID: fmt.Sprintf("peer-%d", i), DNSLabel: fmt.Sprintf("peer-%d", i), Key: peerKey.PublicKey().String(), - IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + IP: netip.MustParseAddr(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + IPv6: netip.MustParseAddr(fmt.Sprintf("fd00::%d", i+1)), Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: regularUser, } @@ -804,7 +784,15 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou account.Networks = append(account.Networks, network) ips := account.GetTakenIPs() - peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) + peerIP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, nil, "", "", err + } + v6Prefix, err := netip.ParsePrefix(account.Network.NetV6.String()) + if err != nil { + return nil, nil, "", "", err + } + peerIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, nil, "", "", err } @@ -815,6 +803,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou DNSLabel: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), Key: peerKey.PublicKey().String(), IP: peerIP, + IPv6: peerIPv6, Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: 
regularUser, Meta: nbpeer.PeerSystemMeta{ @@ -996,7 +985,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.UpdateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{}) } duration := time.Since(start) @@ -1016,11 +1005,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } -func TestUpdateAccountPeers_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testUpdateAccountPeers(t) -} - func TestUpdateAccountPeers(t *testing.T) { testUpdateAccountPeers(t) } @@ -1059,7 +1043,7 @@ func testUpdateAccountPeers(t *testing.T) { peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID) } - manager.UpdateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{}) for _, channel := range peerChannels { update := <-channel @@ -1094,7 +1078,8 @@ func TestToSyncResponse(t *testing.T) { }, } peer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.1"), + IP: netip.MustParseAddr("192.168.1.1"), + IPv6: netip.MustParseAddr("fd00::1"), SSHEnabled: true, Key: "peer-key", DNSLabel: "peer1", @@ -1105,9 +1090,21 @@ func TestToSyncResponse(t *testing.T) { Signature: "turn-pass", } networkMap := &types.NetworkMap{ - Network: &types.Network{Net: *ipnet, Serial: 1000}, - Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, - OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, + Network: &types.Network{Net: *ipnet, Serial: 1000}, + Peers: []*nbpeer.Peer{{ + IP: netip.MustParseAddr("192.168.1.2"), + IPv6: netip.MustParseAddr("fd00::2"), + Key: "peer2-key", + DNSLabel: "peer2", + SSHEnabled: true, + SSHKey: "peer2-ssh-key"}}, + OfflinePeers: []*nbpeer.Peer{{ + IP: netip.MustParseAddr("192.168.1.3"), + IPv6: netip.MustParseAddr("fd00::3"), + Key: 
"peer3-key", + DNSLabel: "peer3", + SSHEnabled: true, + SSHKey: "peer3-ssh-key"}}, Routes: []*nbroute.Route{ { ID: "route1", @@ -1254,6 +1251,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort()) // assert network map Firewall assert.Equal(t, 1, len(response.NetworkMap.FirewallRules)) + //nolint:staticcheck // testing backward-compatible field assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP) assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction) assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action) @@ -1316,7 +1314,8 @@ func Test_RegisterPeerByUser(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -1404,7 +1403,8 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { newPeerTemplate := &nbpeer.Peer{ AccountID: existingAccountID, UserID: "", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -1565,7 +1565,8 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { AccountID: existingAccountID, Key: "newPeerKey", UserID: "", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -1600,7 +1601,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ 
-1651,7 +1651,8 @@ func Test_LoginPeer(t *testing.T) { newPeerTemplate := &nbpeer.Peer{ AccountID: existingAccountID, UserID: "", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -1907,7 +1908,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1929,7 +1930,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1994,7 +1995,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2012,7 +2013,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2058,7 +2059,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2076,7 +2077,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2113,7 +2114,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2131,7 +2132,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case 
<-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2153,14 +2154,16 @@ func Test_DeletePeer(t *testing.T) { ID: "peer1", AccountID: accountID, Key: "key1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1"), DNSLabel: "peer1.test", }, "peer2": { ID: "peer2", AccountID: accountID, Key: "key2", - IP: net.IP{2, 2, 2, 2}, + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2"), DNSLabel: "peer2.test", }, } @@ -2757,6 +2760,67 @@ func TestProcessPeerAddAuth(t *testing.T) { }) } +func TestPeerWillHaveIPv6(t *testing.T) { + settings := &types.Settings{ + IPv6EnabledGroups: []string{"all-group-id", "group-a"}, + } + + assert.True(t, peerWillHaveIPv6(settings, nil, "all-group-id"), "peer in All group should get IPv6") + assert.True(t, peerWillHaveIPv6(settings, []string{"group-a"}, ""), "peer with matching auto-group should get IPv6") + assert.False(t, peerWillHaveIPv6(settings, []string{"group-b"}, "other-all"), "peer with no matching groups should not get IPv6") + assert.False(t, peerWillHaveIPv6(settings, nil, ""), "embedded peer with no groups should not get IPv6") + + emptySettings := &types.Settings{IPv6EnabledGroups: []string{}} + assert.False(t, peerWillHaveIPv6(emptySettings, []string{"group-a"}, "all-group-id"), "no IPv6 groups means no IPv6") +} + +// TestSyncPeer_IPv6CapabilityChangePropagates ensures that when a peer reports +// a new IPv6 overlay capability via SyncPeer (e.g. after a client upgrade or +// flipping --disable-ipv6) without bumping its WtVersion, other account peers +// receive a fresh network map so their AAAA records for it become unstale. 
+func TestSyncPeer_IPv6CapabilityChangePropagates(t *testing.T) { + manager, updateManager, _, peer1, peer2, _ := setupNetworkMapTest(t) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + updateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Drain any initial updates from setup. + drain := func() { + for { + select { + case <-updMsg: + case <-time.After(200 * time.Millisecond): + return + } + } + } + drain() + + t.Run("no propagation when capabilities are unchanged", func(t *testing.T) { + _, _, _, _, err := manager.SyncPeer(context.Background(), types.PeerSync{ + WireGuardPubKey: peer2.Key, + Meta: peer2.Meta, + }, peer2.AccountID) + require.NoError(t, err) + peerShouldNotReceiveUpdate(t, updMsg) + }) + + t.Run("propagation when IPv6 capability is added", func(t *testing.T) { + newMeta := peer2.Meta + newMeta.Capabilities = append([]int32{}, peer2.Meta.Capabilities...) + newMeta.Capabilities = append(newMeta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay) + + _, _, _, _, err := manager.SyncPeer(context.Background(), types.PeerSync{ + WireGuardPubKey: peer2.Key, + Meta: newMeta, + }, peer2.AccountID) + require.NoError(t, err) + peerShouldReceiveUpdate(t, updMsg) + }) +} + func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/policy.go b/management/server/policy.go index 3e84c3d10..40f3908e3 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,6 +5,7 @@ import ( _ "embed" "github.com/rs/xid" + "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var isUpdate = policy.ID != "" var updateAccountPeers bool var action = 
activity.PolicyAdded + var unchanged bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { - return err - } - - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy) if err != nil { return err } - saveFunc := transaction.CreatePolicy if isUpdate { - action = activity.PolicyUpdated - saveFunc = transaction.SavePolicy - } + if policy.Equal(existingPolicy) { + logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID) + unchanged = true + return nil + } - if err = saveFunc(ctx, policy); err != nil { - return err + action = activity.PolicyUpdated + + updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy) + if err != nil { + return err + } + + if err = transaction.SavePolicy(ctx, policy); err != nil { + return err + } + } else { + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) + if err != nil { + return err + } + + if err = transaction.CreatePolicy(ctx, policy); err != nil { + return err + } } return transaction.IncrementNetworkSerial(ctx, accountID) @@ -73,10 +89,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return nil, err } + if unchanged { + return policy, nil + } + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + policyOp := types.UpdateOperationCreate + if isUpdate { + policyOp = types.UpdateOperationUpdate + } + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp}) } return policy, nil @@ -101,7 +125,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - 
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) if err != nil { return err } @@ -119,7 +143,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete}) } return nil @@ -138,34 +162,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } -// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { - if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) - if err != nil { - return false, err - } - - if !policy.Enabled && !existingPolicy.Enabled { - return false, nil - } - - for _, rule := range existingPolicy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil - } - } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) - if err != nil { - return false, err - } - - if hasPeers { +// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers. 
+func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil } } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) +} + +func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) { + if !policy.Enabled && !existingPolicy.Enabled { + return false, nil + } + + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + for _, rule := range policy.Rules { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil @@ -175,12 +202,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } -// validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { +// validatePolicy validates the policy and its rules. For updates it returns +// the existing policy loaded from the store so callers can avoid a second read. 
+func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) { + var existingPolicy *types.Policy if policy.ID != "" { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + var err error + existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { - return err + return nil, err } // TODO: Refactor to support multiple rules per policy @@ -191,7 +221,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri for _, rule := range policy.Rules { if rule.ID != "" && !existingRuleIDs[rule.ID] { - return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) } } } else { @@ -201,12 +231,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) if err != nil { - return err + return nil, err } postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) if err != nil { - return err + return nil, err } for i, rule := range policy.Rules { @@ -225,7 +255,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - return nil + return existingPolicy, nil } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. 
diff --git a/management/server/policy_test.go b/management/server/policy_test.go index a3f987732..1eae07e79 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -3,7 +3,7 @@ package server import ( "context" "fmt" - "net" + "net/netip" "testing" "time" @@ -20,53 +20,53 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), Status: &nbpeer.PeerStatus{}, }, "peerB": { ID: "peerB", - IP: net.ParseIP("100.65.80.39"), + IP: netip.MustParseAddr("100.65.80.39"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"}, }, "peerC": { ID: "peerC", - IP: net.ParseIP("100.65.254.139"), + IP: netip.MustParseAddr("100.65.254.139"), Status: &nbpeer.PeerStatus{}, }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), Status: &nbpeer.PeerStatus{}, }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), Status: &nbpeer.PeerStatus{}, }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), Status: &nbpeer.PeerStatus{}, }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), Status: &nbpeer.PeerStatus{}, }, "peerH": { ID: "peerH", - IP: net.ParseIP("100.65.29.55"), + IP: netip.MustParseAddr("100.65.29.55"), Status: &nbpeer.PeerStatus{}, }, "peerI": { ID: "peerI", - IP: net.ParseIP("100.65.31.2"), + IP: netip.MustParseAddr("100.65.31.2"), Status: &nbpeer.PeerStatus{}, }, "peerK": { ID: "peerK", - IP: net.ParseIP("100.32.80.1"), + IP: netip.MustParseAddr("100.32.80.1"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"}, }, @@ -540,17 +540,17 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: 
net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), Status: &nbpeer.PeerStatus{}, }, "peerB": { ID: "peerB", - IP: net.ParseIP("100.65.80.39"), + IP: netip.MustParseAddr("100.65.80.39"), Status: &nbpeer.PeerStatus{}, }, "peerC": { ID: "peerC", - IP: net.ParseIP("100.65.254.139"), + IP: netip.MustParseAddr("100.65.254.139"), Status: &nbpeer.PeerStatus{}, }, }, @@ -746,7 +746,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -756,7 +756,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerB": { ID: "peerB", - IP: net.ParseIP("100.65.80.39"), + IP: netip.MustParseAddr("100.65.80.39"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -766,7 +766,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerC": { ID: "peerC", - IP: net.ParseIP("100.65.254.139"), + IP: netip.MustParseAddr("100.65.254.139"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -776,7 +776,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -786,7 +786,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -796,7 +796,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ 
-806,7 +806,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -816,7 +816,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerH": { ID: "peerH", - IP: net.ParseIP("100.65.29.55"), + IP: netip.MustParseAddr("100.65.29.55"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -826,7 +826,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerI": { ID: "peerI", - IP: net.ParseIP("100.65.21.56"), + IP: netip.MustParseAddr("100.65.21.56"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "windows", @@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { 
select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } @@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/posture/network.go b/management/server/posture/network.go index f78744143..4b4b3ccaa 100644 --- a/management/server/posture/network.go +++ b/management/server/posture/network.go @@ -17,19 +17,48 @@ type PeerNetworkRangeCheck struct { var _ Check = (*PeerNetworkRangeCheck)(nil) +// prefixContains reports whether outer fully contains inner (equal counts as contained). +// Requires the same address family, that outer is no more specific than inner (its +// netmask is shorter or equal), and that inner's network address falls inside outer. +// This is stricter than netip.Prefix.Contains(Addr) — a peer's /24 NIC will not match a +// configured /32 rule, since the rule covers a single host but the NIC describes a whole +// subnet whose host bits are unknown. +func prefixContains(outer, inner netip.Prefix) bool { + outer = outer.Masked() + inner = inner.Masked() + return outer.Bits() <= inner.Bits() && + outer.Addr().BitLen() == inner.Addr().BitLen() && // same family + outer.Contains(inner.Addr()) +} + +// Check evaluates configured ranges against the peer's local network interface prefixes +// and its public connection IP (as a /32 or /128). A configured range matches when it +// fully contains one of those prefixes, so operators can target both private subnets +// and public CIDRs (e.g. 1.0.0.0/24, 2.2.2.2/32). Including the connection IP is what +// lets a public-range posture check work — peer.Meta.NetworkAddresses only carries +// local NIC addresses. 
func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { - if len(peer.Meta.NetworkAddresses) == 0 { + peerPrefixes := make([]netip.Prefix, 0, len(peer.Meta.NetworkAddresses)+1) + for _, peerNetAddr := range peer.Meta.NetworkAddresses { + peerPrefixes = append(peerPrefixes, peerNetAddr.NetIP) + } + // Unmap collapses 4-in-6 forms (::ffff:a.b.c.d) so an IPv4 range matches. + if connIP := peer.Location.ConnectionIP; len(connIP) > 0 { + if addr, ok := netip.AddrFromSlice(connIP); ok { + addr = addr.Unmap() + peerPrefixes = append(peerPrefixes, netip.PrefixFrom(addr, addr.BitLen())) + } + } + + if len(peerPrefixes) == 0 { return false, fmt.Errorf("peer's does not contain peer network range addresses") } - maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges)) - for _, prefix := range p.Ranges { - maskedPrefixes = append(maskedPrefixes, prefix.Masked()) - } - - for _, peerNetAddr := range peer.Meta.NetworkAddresses { - peerMaskedPrefix := peerNetAddr.NetIP.Masked() - if slices.Contains(maskedPrefixes, peerMaskedPrefix) { + for _, peerPrefix := range peerPrefixes { + for _, rangePrefix := range p.Ranges { + if !prefixContains(rangePrefix, peerPrefix) { + continue + } switch p.Action { case CheckActionDeny: return false, nil diff --git a/management/server/posture/network_test.go b/management/server/posture/network_test.go index a841bbe08..4af394c62 100644 --- a/management/server/posture/network_test.go +++ b/management/server/posture/network_test.go @@ -2,6 +2,7 @@ package posture import ( "context" + "net" "net/netip" "testing" @@ -134,6 +135,205 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) { wantErr: true, isValid: false, }, + { + name: "Peer connection IP matches the denied /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: 
netip.MustParsePrefix("192.168.0.123/24")}, + }, + }, + Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Peer connection IP does not match the denied /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.0.123/24")}, + }, + }, + Location: nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Peer connection IP matches the allowed /32 with no NetworkAddresses", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "IPv6 connection IP matches the denied /128", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::1")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "IPv6 connection IP does not match the denied /128", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::2")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "IPv4-mapped IPv6 connection IP matches IPv4 /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("::ffff:109.41.115.194")}, 
+ }, + wantErr: false, + isValid: false, + }, + { + name: "Connection IP falls inside an allowed /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + netip.MustParsePrefix("2.2.2.2/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.55")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Connection IP falls inside an allowed /23 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("3.0.0.0/23"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("3.0.1.200")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Connection IP outside the allowed /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.1.5")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Connection IP inside a denied /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.7")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Local NIC /24 does not match a /32 rule even if host bit lines up", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.5/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.0.5/24")}, + }, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "Local NIC address inside an allowed /16 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + 
netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.5.7/24")}, + }, + }, + }, + wantErr: false, + isValid: true, + }, + { + name: "Empty NetworkAddresses and empty ConnectionIP still errors", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{}, + wantErr: true, + isValid: false, + }, } for _, tt := range tests { diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9562487c0..1e3ce4b8a 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -76,7 +77,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + postureOp := types.UpdateOperationCreate + if isUpdate { + postureOp = types.UpdateOperationUpdate + } + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp}) } return postureChecks, nil diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7f0a48dc7..394f0d896 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case 
<-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/route.go b/management/server/route.go index 2b4f11d05..a9561faf0 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -191,7 +191,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate}) } return newRoute, nil @@ -245,7 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: 
types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate}) } return nil @@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete}) } return nil diff --git a/management/server/route_test.go b/management/server/route_test.go index 91b2cf982..79014790f 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,10 +2,7 @@ package server import ( "context" - "fmt" - "net" "net/netip" - "sort" "testing" "time" @@ -1335,14 +1332,24 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou return nil, err } + v6Prefix, err := netip.ParsePrefix(account.Network.NetV6.String()) + if err != nil { + return nil, err + } + ips := account.GetTakenIPs() - peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer1IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer1IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer1 := &nbpeer.Peer{ IP: peer1IP, + IPv6: peer1IPv6, ID: peer1ID, Key: peer1Key, Name: "test-host1@netbird.io", @@ -1363,13 +1370,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer1.ID] = peer1 ips = account.GetTakenIPs() - peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer2IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer2IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer2 := &nbpeer.Peer{ IP: peer2IP, + IPv6: peer2IPv6, ID: peer2ID, Key: 
peer2Key, Name: "test-host2@netbird.io", @@ -1390,13 +1402,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer2.ID] = peer2 ips = account.GetTakenIPs() - peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer3IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer3IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer3 := &nbpeer.Peer{ IP: peer3IP, + IPv6: peer3IPv6, ID: peer3ID, Key: peer3Key, Name: "test-host3@netbird.io", @@ -1417,13 +1434,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer3.ID] = peer3 ips = account.GetTakenIPs() - peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer4IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer4IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer4 := &nbpeer.Peer{ IP: peer4IP, + IPv6: peer4IPv6, ID: peer4ID, Key: peer4Key, Name: "test-host4@netbird.io", @@ -1444,13 +1466,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer4.ID] = peer4 ips = account.GetTakenIPs() - peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer5IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer5IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer5 := &nbpeer.Peer{ IP: peer5IP, + IPv6: peer5IPv6, ID: peer5ID, Key: peer5Key, Name: "test-host5@netbird.io", @@ -1551,7 +1578,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), + IPv6: 
netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -1559,18 +1587,21 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, "peerB": { ID: "peerB", - IP: net.ParseIP(peerBIp), + IP: netip.MustParseAddr(peerBIp), + IPv6: netip.MustParseAddr("fd00::2"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{}, }, "peerC": { ID: "peerC", - IP: net.ParseIP(peerCIp), + IP: netip.MustParseAddr(peerCIp), + IPv6: netip.MustParseAddr("fd00::3"), Status: &nbpeer.PeerStatus{}, }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), + IPv6: netip.MustParseAddr("fd00::4"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -1578,7 +1609,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), + IPv6: netip.MustParseAddr("fd00::5"), Key: peer1Key, Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ @@ -1587,27 +1619,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), + IPv6: netip.MustParseAddr("fd00::6"), Status: &nbpeer.PeerStatus{}, }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), + IPv6: netip.MustParseAddr("fd00::7"), Status: &nbpeer.PeerStatus{}, }, "peerH": { ID: "peerH", - IP: net.ParseIP(peerHIp), + IP: netip.MustParseAddr(peerHIp), + IPv6: netip.MustParseAddr("fd00::8"), Status: &nbpeer.PeerStatus{}, }, "peerJ": { ID: "peerJ", - IP: net.ParseIP(peerJIp), + IP: netip.MustParseAddr(peerJIp), + IPv6: netip.MustParseAddr("fd00::a"), Status: &nbpeer.PeerStatus{}, }, "peerK": { ID: "peerK", - IP: net.ParseIP(peerKIp), + IP: netip.MustParseAddr(peerKIp), + IPv6: netip.MustParseAddr("fd00::b"), Status: &nbpeer.PeerStatus{}, }, }, @@ -1840,11 +1877,6 
@@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, } - validatedPeers := make(map[string]struct{}) - for p := range account.Peers { - validatedPeers[p] = struct{}{} - } - t.Run("check applied policies for the route", func(t *testing.T) { route1 := account.Routes["route1"] policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) @@ -1858,116 +1890,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) assert.Len(t, policies, 0) }) - - t.Run("check peer routes firewall rules", func(t *testing.T) { - routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 4) - - expectedRoutesFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 80, - RouteID: "route1:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 320, - RouteID: "route1:peerA", - }, - } - additionalFirewallRule := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerJIp), - }, - Action: "accept", - Destination: "192.168.10.0/16", - Protocol: "tcp", - Port: 80, - RouteID: "route4:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerKIp), - }, - Action: "accept", - Destination: "192.168.10.0/16", - Protocol: "all", - RouteID: "route4:peerA", - }, - } - - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), 
orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) - - // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) - assert.Len(t, routesFirewallRules, 2) - for _, rule := range expectedRoutesFirewallRules { - rule.RouteID = "route1:peerD" - } - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) - - // peerE is a single routing peer for route 2 and route 3 - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) - assert.Len(t, routesFirewallRules, 3) - - expectedRoutesFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, - Action: "accept", - Destination: existingNetwork.String(), - Protocol: "tcp", - PortRange: types.RulePortRange{Start: 80, End: 350}, - RouteID: "route2", - }, - { - SourceRanges: []string{"0.0.0.0/0"}, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "route3", - }, - { - SourceRanges: []string{"::/0"}, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "route3", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) - - // peerC is part of route1 distribution groups but should not receive the routes firewall rules - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) - assert.Len(t, routesFirewallRules, 0) - }) - -} - -// orderList is a helper function to sort a list of strings -func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule { - for _, rule := range 
ruleList { - sort.Strings(rule.SourceRanges) - } - return ruleList } func TestRouteAccountPeersUpdate(t *testing.T) { @@ -2070,7 +1992,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } @@ -2107,7 +2029,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2127,7 +2049,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2145,7 +2067,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2185,7 +2107,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2225,7 +2147,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2246,84 +2168,101 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), + IPv6: netip.MustParseAddr("fd00::1"), Key: "peerA", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerB": { ID: "peerB", - IP: net.ParseIP(peerBIp), + IP: 
netip.MustParseAddr(peerBIp), + IPv6: netip.MustParseAddr("fd00::2"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{}, }, "peerC": { ID: "peerC", - IP: net.ParseIP(peerCIp), + IP: netip.MustParseAddr(peerCIp), + IPv6: netip.MustParseAddr("fd00::3"), Status: &nbpeer.PeerStatus{}, }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), + IPv6: netip.MustParseAddr("fd00::4"), Key: "peerD", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), + IPv6: netip.MustParseAddr("fd00::5"), Key: "peerE", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), + IPv6: netip.MustParseAddr("fd00::6"), Status: &nbpeer.PeerStatus{}, }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), + IPv6: netip.MustParseAddr("fd00::7"), Status: &nbpeer.PeerStatus{}, }, "peerH": { ID: "peerH", - IP: net.ParseIP(peerHIp), + IP: netip.MustParseAddr(peerHIp), + IPv6: netip.MustParseAddr("fd00::8"), Status: &nbpeer.PeerStatus{}, }, "peerJ": { ID: "peerJ", - IP: net.ParseIP(peerJIp), + IP: netip.MustParseAddr(peerJIp), + IPv6: netip.MustParseAddr("fd00::a"), Status: &nbpeer.PeerStatus{}, }, "peerK": { ID: "peerK", - IP: net.ParseIP(peerKIp), + IP: netip.MustParseAddr(peerKIp), + IPv6: netip.MustParseAddr("fd00::b"), Status: &nbpeer.PeerStatus{}, }, "peerL": { ID: "peerL", - IP: net.ParseIP("100.65.19.186"), + IP: netip.MustParseAddr("100.65.19.186"), + IPv6: netip.MustParseAddr("fd00::d"), Key: "peerL", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + 
GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerM": { ID: "peerM", - IP: net.ParseIP(peerMIp), + IP: netip.MustParseAddr(peerMIp), + IPv6: netip.MustParseAddr("fd00::e"), Status: &nbpeer.PeerStatus{}, }, "peerN": { ID: "peerN", - IP: net.ParseIP("100.65.20.18"), + IP: netip.MustParseAddr("100.65.20.18"), + IPv6: netip.MustParseAddr("fd00::f"), Key: "peerN", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ @@ -2332,7 +2271,8 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { }, "peerO": { ID: "peerO", - IP: net.ParseIP(peerOIp), + IP: netip.MustParseAddr(peerOIp), + IPv6: netip.MustParseAddr("fd00::10"), Status: &nbpeer.PeerStatus{}, }, }, @@ -2665,11 +2605,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { }, } - validatedPeers := make(map[string]struct{}) - for p := range account.Peers { - validatedPeers[p] = struct{}{} - } - t.Run("validate applied policies for different network resources", func(t *testing.T) { // Test case: Resource1 is directly applied to the policy (policyResource1) policies := account.GetPoliciesForNetworkResource("resource1") @@ -2693,127 +2628,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { policies = account.GetPoliciesForNetworkResource("resource6") assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") }) - - t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { - resourcePoliciesMap := account.GetResourcePoliciesMap() - resourceRoutersMap := account.GetResourceRoutersMap() - _, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap) - firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 4) - assert.Len(t, sourcePeers, 5) - - 
expectedFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 80, - RouteID: "resource2:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 320, - RouteID: "resource2:peerA", - }, - } - - additionalFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerJIp), - }, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "tcp", - Port: 80, - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "resource4:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerKIp), - }, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "resource4:peerA", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) - - // peerD is also the routing peer for resource2 - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 2) - for _, rule := range expectedFirewallRules { - rule.RouteID = "resource2:peerD" - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - assert.Len(t, sourcePeers, 3) - - // peerE is a single 
routing peer for resource1 and resource3 - // PeerE should only receive rules for resource1 since resource3 has no applied policy - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 1) - assert.Len(t, sourcePeers, 2) - - expectedFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, - Action: "accept", - Destination: "10.10.10.0/24", - Protocol: "tcp", - PortRange: types.RulePortRange{Start: 80, End: 350}, - RouteID: "resource1:peerE", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - - // peerC is part of distribution groups for resource2 but should not receive the firewall rules - firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) - assert.Len(t, firewallRules, 0) - - // peerL is the single routing peer for resource5 - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 1) - assert.Len(t, sourcePeers, 1) - - expectedFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.29.67/32"}, - Action: "accept", - Destination: "10.12.12.1/32", - Protocol: "tcp", - Port: 8080, - RouteID: "resource5:peerL", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), 
"peerM", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - assert.Len(t, sourcePeers, 0) - - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - assert.Len(t, sourcePeers, 2) - }) } diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 74af0a3ef..345d857f9 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,6 +5,7 @@ package settings import ( "context" "fmt" + "net/netip" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -22,6 +23,9 @@ type Manager interface { GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) GetExtraSettings(ctx context.Context, accountID string) (*types.ExtraSettings, error) UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error) + // GetEffectiveNetworkRanges returns the actual allocated network ranges (v4 and v6). + // This includes auto-allocated ranges even when no custom override was set. + GetEffectiveNetworkRanges(ctx context.Context, accountID string) (v4, v6 netip.Prefix, err error) } // IdpConfig holds IdP-related configuration that is set at runtime @@ -115,3 +119,28 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (* func (m *managerImpl) UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error) { return m.extraSettingsManager.UpdateExtraSettings(ctx, accountID, userID, extraSettings) } + +// GetEffectiveNetworkRanges returns the actual allocated network ranges from the account's network object. 
+func (m *managerImpl) GetEffectiveNetworkRanges(ctx context.Context, accountID string) (netip.Prefix, netip.Prefix, error) { + network, err := m.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return netip.Prefix{}, netip.Prefix{}, fmt.Errorf("get account network: %w", err) + } + + var v4, v6 netip.Prefix + if network.Net.IP != nil { + addr, ok := netip.AddrFromSlice(network.Net.IP) + if ok { + ones, _ := network.Net.Mask.Size() + v4 = netip.PrefixFrom(addr.Unmap(), ones) + } + } + if network.NetV6.IP != nil { + addr, ok := netip.AddrFromSlice(network.NetV6.IP) + if ok { + ones, _ := network.NetV6.Mask.Size() + v6 = netip.PrefixFrom(addr.Unmap(), ones) + } + } + return v4, v6, nil +} diff --git a/management/server/settings/manager_mock.go b/management/server/settings/manager_mock.go index dc2f2ebfe..4bedb2cf7 100644 --- a/management/server/settings/manager_mock.go +++ b/management/server/settings/manager_mock.go @@ -6,6 +6,7 @@ package settings import ( context "context" + netip "net/netip" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -94,3 +95,19 @@ func (mr *MockManagerMockRecorder) UpdateExtraSettings(ctx, accountID, userID, e mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateExtraSettings", reflect.TypeOf((*MockManager)(nil).UpdateExtraSettings), ctx, accountID, userID, extraSettings) } + +// GetEffectiveNetworkRanges mocks base method. +func (m *MockManager) GetEffectiveNetworkRanges(ctx context.Context, accountID string) (netip.Prefix, netip.Prefix, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEffectiveNetworkRanges", ctx, accountID) + ret0, _ := ret[0].(netip.Prefix) + ret1, _ := ret[1].(netip.Prefix) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetEffectiveNetworkRanges indicates an expected call of GetEffectiveNetworkRanges. 
+func (mr *MockManagerMockRecorder) GetEffectiveNetworkRanges(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEffectiveNetworkRanges", reflect.TypeOf((*MockManager)(nil).GetEffectiveNetworkRanges), ctx, accountID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8189548b7..973101ce3 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net" + "net/netip" "os" "path/filepath" "runtime" @@ -1196,7 +1197,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types account.NameServerGroups[ns.ID] = &ns } account.NameServerGroupsG = nil - account.InitOnce() return &account, nil } @@ -1504,7 +1504,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc SELECT id, created_by, created_at, domain, domain_category, is_domain_primary_account, -- Embedded Network - network_identifier, network_net, network_dns, network_serial, + network_identifier, network_net, network_net_v6, network_dns, network_serial, -- Embedded DNSSettings dns_settings_disabled_management_groups, -- Embedded Settings @@ -1513,7 +1513,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc settings_regular_users_view_blocked, settings_groups_propagation_enabled, settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, - settings_lazy_connection_enabled, + settings_network_range_v6, settings_ipv6_enabled_groups, settings_lazy_connection_enabled, -- Embedded ExtraSettings settings_extra_peer_approval_enabled, settings_extra_user_approval_required, settings_extra_integrated_validator, settings_extra_integrated_validator_groups @@ -1532,12 +1532,15 @@ func (s *SqlStore) getAccount(ctx context.Context, 
accountID string) (*types.Acc sRoutingPeerDNSResolutionEnabled sql.NullBool sDNSDomain sql.NullString sNetworkRange sql.NullString + sNetworkRangeV6 sql.NullString + sIPv6EnabledGroups sql.NullString sLazyConnectionEnabled sql.NullBool sExtraPeerApprovalEnabled sql.NullBool sExtraUserApprovalRequired sql.NullBool sExtraIntegratedValidator sql.NullString sExtraIntegratedValidatorGroups sql.NullString networkNet sql.NullString + networkNetV6 sql.NullString dnsSettingsDisabledGroups sql.NullString networkIdentifier sql.NullString networkDns sql.NullString @@ -1546,14 +1549,14 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc ) err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, - &networkIdentifier, &networkNet, &networkDns, &networkSerial, + &networkIdentifier, &networkNet, &networkNetV6, &networkDns, &networkSerial, &dnsSettingsDisabledGroups, &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, - &sLazyConnectionEnabled, + &sNetworkRangeV6, &sIPv6EnabledGroups, &sLazyConnectionEnabled, &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, ) @@ -1622,6 +1625,15 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc if sNetworkRange.Valid { _ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange) } + if networkNetV6.Valid { + _ = json.Unmarshal([]byte(networkNetV6.String), &account.Network.NetV6) + } + if sNetworkRangeV6.Valid { + _ = json.Unmarshal([]byte(sNetworkRangeV6.String), &account.Settings.NetworkRangeV6) + } + if sIPv6EnabledGroups.Valid { + _ = 
json.Unmarshal([]byte(sIPv6EnabledGroups.String), &account.Settings.IPv6EnabledGroups) + } if sExtraPeerApprovalEnabled.Valid { account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool @@ -1635,7 +1647,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc if sExtraIntegratedValidatorGroups.Valid { _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) } - account.InitOnce() return &account, nil } @@ -1704,12 +1715,12 @@ func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types. func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, - inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, - meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, + inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, + meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, - meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, - peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, - location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster FROM peers WHERE account_id = $1` + meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired, + peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, + 
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -1723,7 +1734,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool peerStatusLastSeen sql.NullTime peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool - ip, extraDNS, netAddr, env, flags, files, connIP []byte + ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString @@ -1735,9 +1746,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, - &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, + &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, - &locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster) + &locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6) if err == nil { if lastLogin.Valid { @@ -1830,6 +1841,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee if ip != nil { _ = json.Unmarshal(ip, &p.IP) } + if ipv6 != 
nil { + _ = json.Unmarshal(ipv6, &p.IPv6) + } if extraDNS != nil { _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) } @@ -1845,6 +1859,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee if files != nil { _ = json.Unmarshal(files, &p.Meta.Files) } + if capabilities != nil { + _ = json.Unmarshal(capabilities, &p.Meta.Capabilities) + } if connIP != nil { _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) } @@ -2588,7 +2605,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) return accountID, nil } -func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { +func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]netip.Addr, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2596,7 +2613,6 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength var ipJSONStrings []string - // Fetch the IP addresses as JSON strings result := tx.Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). 
Pluck("ip", &ipJSONStrings) @@ -2607,14 +2623,13 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } - // Convert the JSON strings to net.IP objects - ips := make([]net.IP, len(ipJSONStrings)) + ips := make([]netip.Addr, len(ipJSONStrings)) for i, ipJSON := range ipJSONStrings { - var ip net.IP + var ip netip.Addr if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil { return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store") } - ips[i] = ip + ips[i] = ip.Unmap() } return ips, nil @@ -3216,7 +3231,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre query = query.Where("name LIKE ?", "%"+nameFilter+"%") } if ipFilter != "" { - query = query.Where("ip LIKE ?", "%"+ipFilter+"%") + query = query.Where("ip LIKE ? OR ipv6 LIKE ?", "%"+ipFilter+"%", "%"+ipFilter+"%") } if err := query.Find(&peers).Error; err != nil { @@ -3310,7 +3325,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng var peers []*nbpeer.Peer result := tx. - Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true). Find(&peers, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) @@ -4092,9 +4107,10 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, accountID string, se return status.Errorf(status.Internal, "failed to save account settings to store") } - if result.RowsAffected == 0 { - return status.NewAccountNotFoundError(accountID) - } + // MySQL reports RowsAffected=0 for no-op updates where values don't change, + // unlike SQLite/Postgres which report matched rows. 
Skip the check since the + // caller (UpdateAccountSettings) already verified the account exists via + // GetAccountSettings with LockingStrengthUpdate. return nil } @@ -4519,11 +4535,15 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } + column := "ip" + if ip.To4() == nil { + column = "ipv6" + } jsonValue := fmt.Sprintf(`"%s"`, ip.String()) var peer nbpeer.Peer result := tx. - Take(&peer, "account_id = ? AND ip = ?", accountID, jsonValue) + Take(&peer, fmt.Sprintf("account_id = ? AND %s = ?", column), accountID, jsonValue) if result.Error != nil { // no logging here return nil, status.Errorf(status.Internal, "failed to get peer from store") @@ -4645,6 +4665,27 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i return nil } +// UpdateAccountNetworkV6 updates the IPv6 network range for the account. +func (s *SqlStore) UpdateAccountNetworkV6(ctx context.Context, accountID string, ipNet net.IPNet) error { + patch := accountNetworkPatch{ + Network: &types.Network{NetV6: ipNet}, + } + + result := s.db. + Model(&types.Account{}). + Where(idQueryCondition, accountID). 
+ Updates(&patch) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update account network v6: %v", result.Error) + return status.Errorf(status.Internal, "update account network v6") + } + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + return nil +} + func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) { if len(groupIDs) == 0 { return []*nbpeer.Peer{}, nil @@ -5439,13 +5480,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { return nil } -// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist -func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// DisconnectProxy marks a proxy as disconnected only if the session ID matches. +// This prevents a slow-to-close old session from overwriting a newer reconnection. +func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + now := time.Now() + result := s.db. + Model(&proxy.Proxy{}). + Where("id = ? AND session_id = ?", proxyID, sessionID). + Updates(map[string]any{ + "status": "disconnected", + "disconnected_at": now, + "last_seen": now, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error) + return status.Errorf(status.Internal, "failed to disconnect proxy") + } + if result.RowsAffected == 0 { + log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID) + } + return nil +} + +// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session. +func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { now := time.Now() result := s.db. Model(&proxy.Proxy{}). - Where("id = ? AND status = ?", proxyID, "connected"). + Where("id = ? 
AND session_id = ?", p.ID, p.SessionID). Update("last_seen", now) if result.Error != nil { @@ -5454,17 +5517,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd } if result.RowsAffected == 0 { - p := &proxy.Proxy{ - ID: proxyID, - ClusterAddress: clusterAddress, - IPAddress: ipAddress, - LastSeen: now, - ConnectedAt: &now, - Status: "connected", - } - if err := s.db.Save(p).Error; err != nil { - log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err) - return status.Errorf(status.Internal, "failed to create proxy on heartbeat") + p.LastSeen = now + p.ConnectedAt = &now + p.Status = "connected" + if err := s.db.Create(p).Error; err != nil { + log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err) } } diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go index 69e346ae7..9a9de8cdd 100644 --- a/management/server/store/sql_store_get_account_test.go +++ b/management/server/store/sql_store_get_account_test.go @@ -148,7 +148,8 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { AccountID: accountID, Key: "peer-key-1-AAAA", Name: "Peer 1", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{ Hostname: "peer1.example.com", GoOS: "linux", @@ -195,7 +196,8 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { AccountID: accountID, Key: "peer-key-2-BBBB", Name: "Peer 2", - IP: net.ParseIP("100.64.0.2"), + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), Meta: nbpeer.PeerSystemMeta{ Hostname: "peer2.example.com", GoOS: "darwin", @@ -232,7 +234,8 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { AccountID: accountID, Key: "peer-key-3-CCCC", Name: "Peer 3 (Ephemeral)", - IP: net.ParseIP("100.64.0.3"), + IP: 
netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00::3"), Meta: nbpeer.PeerSystemMeta{ Hostname: "peer3.example.com", GoOS: "windows", @@ -710,7 +713,7 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { require.True(t, exists, "Peer 1 should exist") assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") - assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, netip.MustParseAddr("100.64.0.1"), p1.IP, "Peer 1 IP mismatch") assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 8ea6c2ae5..2819265c3 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -94,11 +94,12 @@ func runLargeTest(t *testing.T, store Store) { for n := 0; n < numPerAccount; n++ { netIP := randomIPv4() peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + addr, _ := netip.AddrFromSlice(netIP) peer := &nbpeer.Peer{ ID: peerID, Key: peerID, - IP: netIP, + IP: addr.Unmap(), Name: peerID, DNSLabel: peerID, UserID: "testuser", @@ -235,7 +236,8 @@ func Test_SaveAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -249,7 +251,8 @@ func Test_SaveAccount(t *testing.T) { account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 2}), + IPv6: 
netip.MustParseAddr("fd00::2"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name 2", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -316,7 +319,8 @@ func TestSqlite_DeleteAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -499,7 +503,8 @@ func TestSqlStore_SavePeer(t *testing.T) { peer := &nbpeer.Peer{ Key: "peerkey", ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -556,7 +561,8 @@ func TestSqlStore_SavePeerStatus(t *testing.T) { account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -784,7 +790,8 @@ func newAccount(store Store, id int) error { account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ Key: "peerkey" + str, - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -823,7 +830,8 @@ func TestPostgresql_SaveAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: 
netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -837,7 +845,8 @@ func TestPostgresql_SaveAccount(t *testing.T) { account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 2}), + IPv6: netip.MustParseAddr("fd00::2"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name 2", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -903,7 +912,8 @@ func TestPostgresql_DeleteAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -1010,37 +1020,39 @@ func TestSqlite_GetTakenIPs(t *testing.T) { takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) - assert.Equal(t, []net.IP{}, takenIPs) + assert.Equal(t, []netip.Addr{}, takenIPs) peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, Key: "key1", DNSLabel: "peer1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), } err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) - ip1 := net.IP{1, 1, 1, 1}.To16() - assert.Equal(t, []net.IP{ip1}, takenIPs) + ip1 := netip.AddrFrom4([4]byte{1, 1, 1, 1}) + assert.Equal(t, []netip.Addr{ip1}, takenIPs) peer2 := &nbpeer.Peer{ ID: "peer1second", AccountID: existingAccountID, Key: "key2", DNSLabel: "peer1-1", - IP: net.IP{2, 2, 2, 2}, + IP: 
netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2:2:2:2"), } err = store.AddPeerToAccount(context.Background(), peer2) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) - ip2 := net.IP{2, 2, 2, 2}.To16() - assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) + ip2 := netip.AddrFrom4([4]byte{2, 2, 2, 2}) + assert.Equal(t, []netip.Addr{ip1, ip2}, takenIPs) } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { @@ -1060,7 +1072,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, Key: "key1", DNSLabel: "peer1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), } err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) @@ -1074,7 +1087,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, Key: "key2", DNSLabel: "peer1-1", - IP: net.IP{2, 2, 2, 2}, + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2:2:2:2"), } err = store.AddPeerToAccount(context.Background(), peer2) require.NoError(t, err) @@ -1127,7 +1141,8 @@ func Test_AddPeerWithSameIP(t *testing.T) { ID: "peer1", AccountID: existingAccountID, Key: "key1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), } err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) @@ -1136,7 +1151,8 @@ func Test_AddPeerWithSameIP(t *testing.T) { ID: "peer1second", AccountID: existingAccountID, Key: "key2", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::2:2:2:2"), } err = store.AddPeerToAccount(context.Background(), peer2) require.Error(t, err) @@ -2640,7 +2656,8 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { ID: "peer1", AccountID: accountID, Key: "key", - IP: net.IP{1, 1, 1, 1}, 
+ IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), Meta: nbpeer.PeerSystemMeta{ Hostname: "hostname", GoOS: "linux", @@ -2729,7 +2746,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { { name: "should retrieve peers for an existing account ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 4, + expectedCount: 5, }, { name: "should return no peers for a non-existing account ID", @@ -2751,7 +2768,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { name: "should filter peers by partial name", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", nameFilter: "host", - expectedCount: 3, + expectedCount: 4, }, { name: "should filter peers by ip", @@ -2777,14 +2794,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - accountID string - expectedCount int + name string + accountID string + expectedCount int + expectedPeerIDs []string }{ { - name: "should retrieve peers with expiration for an existing account ID", - accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 1, + name: "should retrieve only non-expired peers with expiration enabled", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + expectedPeerIDs: []string{"notexpired01"}, }, { name: "should return no peers with expiration for a non-existing account ID", @@ -2803,10 +2822,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) + for i, peer := range peers { + assert.Equal(t, tt.expectedPeerIDs[i], peer.ID) + } }) } } +func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + 
require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID) + require.NoError(t, err) + + // Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned") + assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set") + } +} + func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) @@ -2887,7 +2926,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) { name: "should retrieve peers for another valid account ID and user ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", userID: "edafee4e-63fb-11ec-90d6-0242ac120003", - expectedCount: 2, + expectedCount: 3, }, { name: "should return no peers for existing account ID with empty user ID", @@ -3793,10 +3832,10 @@ func BenchmarkGetAccountPeers(b *testing.B) { } } -func intToIPv4(n uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, n) - return ip +func intToIPv4(n uint32) netip.Addr { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + return netip.AddrFrom4(b) } func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { @@ -3923,7 +3962,8 @@ func TestSqlStore_GetUserIDByPeerKey(t *testing.T) { Key: peerKey, AccountID: existingAccountID, UserID: userID, - IP: net.IP{10, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{10, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::a00:1"), DNSLabel: "test-peer-1", } @@ -3960,7 +4000,8 @@ func TestSqlStore_GetUserIDByPeerKey_NoUserID(t *testing.T) { Key: peerKey, AccountID: existingAccountID, UserID: "", - IP: net.IP{10, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{10, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::a00:1"), DNSLabel: 
"test-peer-1", } @@ -3987,7 +4028,8 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { AccountID: accountID, DNSLabel: "peer1.netbird.cloud", Key: "peer1-key", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{ RequiresApproval: true, LastSeen: time.Now().UTC(), @@ -3998,7 +4040,8 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { AccountID: accountID, DNSLabel: "peer2.netbird.cloud", Key: "peer2-key", - IP: net.ParseIP("100.64.0.2"), + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), Status: &nbpeer.PeerStatus{ RequiresApproval: true, LastSeen: time.Now().UTC(), @@ -4009,7 +4052,8 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { AccountID: accountID, DNSLabel: "peer3.netbird.cloud", Key: "peer3-key", - IP: net.ParseIP("100.64.0.3"), + IP: netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00::3"), Status: &nbpeer.PeerStatus{ RequiresApproval: false, LastSeen: time.Now().UTC(), diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 81c4b33ae..a38b4a8c1 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -344,7 +344,8 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { ID: fmt.Sprintf("peer-%d", i), AccountID: accountID, Key: fmt.Sprintf("peerkey-%d", i), - IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + IP: netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", i+1)), + IPv6: netip.MustParseAddr(fmt.Sprintf("fd00::%d", i+1)), Name: fmt.Sprintf("peer-name-%d", i), Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, }) diff --git a/management/server/store/store.go b/management/server/store/store.go index 0d8b0678a..db98bc644 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -185,7 +185,7 @@ type Store interface { 
SaveNameServerGroup(ctx context.Context, nameServerGroup *dns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nameServerGroupID string) error - GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]netip.Addr, error) IncrementNetworkSerial(ctx context.Context, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) @@ -225,6 +225,7 @@ type Store interface { IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error + UpdateAccountNetworkV6(ctx context.Context, accountID string, ipNet net.IPNet) error GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) // SetFieldEncrypt sets the field encryptor for encrypting sensitive user data. 
@@ -284,7 +285,8 @@ type Store interface { DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + DisconnectProxy(ctx context.Context, proxyID, sessionID string) error + UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index beee13d96..6c2c9bbc3 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -7,6 +7,7 @@ package store import ( context "context" net "net" + netip "net/netip" reflect "reflect" time "time" @@ -178,6 +179,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) } + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -2137,10 +2139,10 @@ func (mr *MockStoreMockRecorder) GetStoreEngine() *gomock.Call { } // GetTakenIPs mocks base method. 
-func (m *MockStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) { +func (m *MockStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]netip.Addr, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTakenIPs", ctx, lockStrength, accountId) - ret0, _ := ret[0].([]net.IP) + ret0, _ := ret[0].([]netip.Addr) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2799,6 +2801,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) } +// DisconnectProxy mocks base method. +func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DisconnectProxy indicates an expected call of DisconnectProxy. +func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID) +} + // SaveProxyAccessToken mocks base method. func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { m.ctrl.T.Helper() @@ -2937,6 +2953,20 @@ func (mr *MockStoreMockRecorder) UpdateAccountNetwork(ctx, accountID, ipNet inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountNetwork", reflect.TypeOf((*MockStore)(nil).UpdateAccountNetwork), ctx, accountID, ipNet) } +// UpdateAccountNetworkV6 mocks base method. 
+func (m *MockStore) UpdateAccountNetworkV6(ctx context.Context, accountID string, ipNet net.IPNet) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountNetworkV6", ctx, accountID, ipNet) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountNetworkV6 indicates an expected call of UpdateAccountNetworkV6. +func (mr *MockStoreMockRecorder) UpdateAccountNetworkV6(ctx, accountID, ipNet interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountNetworkV6", reflect.TypeOf((*MockStore)(nil).UpdateAccountNetworkV6), ctx, accountID, ipNet) +} + // UpdateCustomDomain mocks base method. func (m *MockStore) UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error) { m.ctrl.T.Helper() @@ -2995,17 +3025,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{} } // UpdateProxyHeartbeat mocks base method. -func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p) ret0, _ := ret[0].(error) return ret0 } // UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. 
-func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p) } // UpdateService mocks base method. diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index 3b1e078eb..518aae7eb 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -4,6 +4,7 @@ import ( "context" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -11,6 +12,7 @@ import ( type AccountManagerMetrics struct { ctx context.Context updateAccountPeersDurationMs metric.Float64Histogram + updateAccountPeersCounter metric.Int64Counter getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram peerMetaUpdateCount metric.Int64Counter @@ -48,6 +50,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + updateAccountPeersCounter, err := meter.Int64Counter("management.account.update.account.peers.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of account peers updates triggered, labeled by resource and operation")) + if err != nil { + return nil, err + } + peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"), metric.WithDescription("Number of updates with new meta data from the peers")) @@ -59,6 +68,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account 
ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, + updateAccountPeersCounter: updateAccountPeersCounter, networkMapObjectCount: networkMapObjectCount, peerMetaUpdateCount: peerMetaUpdateCount, }, nil @@ -80,6 +90,16 @@ func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { metrics.networkMapObjectCount.Record(metrics.ctx, count) } +// CountUpdateAccountPeersTriggered increments the counter for account peers updates with resource and operation labels. +func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource, operation string) { + metrics.updateAccountPeersCounter.Add(metrics.ctx, 1, + metric.WithAttributes( + attribute.String("resource", resource), + attribute.String("operation", operation), + ), + ) +} + // CountPeerMetUpdate counts the number of peer meta updates func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index 28e8457e2..e48e6d64a 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { } }) - h.ServeHTTP(w, r.WithContext(ctx)) + // Hold on to req so auth's in-place ctx update is visible after ServeHTTP. 
+ req := r.WithContext(ctx) + h.ServeHTTP(w, req) close(handlerDone) - userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) - if err == nil { - if userAuth.AccountId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId) - } - if userAuth.UserId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId) - } - } + ctx = req.Context() if w.Status() > 399 { log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index dfcaeee6f..189bd1262 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 
09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/types/account.go b/management/server/types/account.go index c448813db..49600163a 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -3,12 +3,10 @@ package types import ( "context" "fmt" - "net" "net/netip" "slices" "strconv" "strings" - "sync" "time" "github.com/hashicorp/go-multierror" @@ -27,7 +25,6 @@ import ( networkTypes 
"github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" @@ -110,16 +107,9 @@ type Account struct { NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` - NetworkMapCache *NetworkMapBuilder `gorm:"-"` - nmapInitOnce *sync.Once `gorm:"-"` - ReverseProxyFreeDomainNonce string } -func (a *Account) InitOnce() { - a.nmapInitOnce = &sync.Once{} -} - // this class is used by gorm only type PrimaryAccountInfo struct { IsDomainPrimaryAccount bool @@ -155,108 +145,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool { o.SignupFormPending == onboarding.SignupFormPending } -// GetRoutesToSync returns the enabled routes for the peer ID and the routes -// from the ACL peers that have distribution groups associated with the peer ID. -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) - peerRoutesMembership := make(LookupMap) - for _, r := range append(routes, peerDisabledRoutes...) 
{ - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups) - filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - routes = append(routes, filteredRoutes...) - } - - return routes -} - -// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership -func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - _, found := peerMemberships[string(r.GetHAUniqueID())] - if !found { - filteredRoutes = append(filteredRoutes, r) - } - } - return filteredRoutes -} - -// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map -func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - for _, groupID := range r.Groups { - _, found := groupListMap[groupID] - if found { - filteredRoutes = append(filteredRoutes, r) - break - } - } - } - return filteredRoutes -} - -// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -// If the given is not a routing peer, then the lists are empty. 
-func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - - peer := a.GetPeer(peerID) - if peer == nil { - log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - for _, r := range a.Routes { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) - continue - } - for _, id := range group.Peers { - if id != peerID { - continue - } - - newPeerRoute := r.Copy() - newPeerRoute.Peer = id - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map - takeRoute(newPeerRoute, id) - break - } - } - if r.Peer == peerID { - takeRoute(r.Copy(), peerID) - } - } - - return enabledRoutes, disabledRoutes -} - // GetRoutesByPrefixOrDomains return list of routes by account and route prefix func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { var routes []*route.Route @@ -276,106 +164,6 @@ func (a *Account) GetGroup(groupID string) *Group { return a.Groups[groupID] } -// GetPeerNetworkMap returns the networkmap for the given peer ID. 
-func (a *Account) GetPeerNetworkMap( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - validatedPeersMap map[string]struct{}, - resourcePolicies map[string][]*Policy, - routers map[string]map[string]*routerTypes.NetworkRouter, - metrics *telemetry.AccountManagerMetrics, - groupIDToUserIDs map[string][]string, -) *NetworkMap { - start := time.Now() - peer := a.Peers[peerID] - if peer == nil { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - if _, ok := validatedPeersMap[peerID]; !ok { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - peerGroups := a.GetPeerGroups(peerID) - - aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs) - // exclude expired peers - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - for _, p := range aclPeers { - expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) - if a.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, p) - continue - } - peersToConnect = append(peersToConnect, p) - } - - routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups) - routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) - var networkResourcesFirewallRules []*RouteFirewallRule - if isRouter { - networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) - } - peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) - - dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones 
[]nbdns.CustomZone - - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) - zones = append(zones, nbdns.CustomZone{ - Domain: peersCustomZone.Domain, - Records: records, - }) - } - - filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) - zones = append(zones, filteredAccountZones...) - - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) - } - - nm := &NetworkMap{ - Peers: peersToConnectIncludingRouters, - Network: a.Network.Copy(), - Routes: slices.Concat(networkResourcesRoutes, routesUpdate), - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), - AuthorizedUsers: authorizedUsers, - EnableSSH: enableSSH, - } - - if metrics != nil { - objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ - "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d", - a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules)) - } - } - - return nm -} - func (a *Account) addNetworksRoutingPeers( networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, @@ -421,39 +209,6 @@ func (a *Account) addNetworksRoutingPeers( return peersToConnect } -func getPeerNSGroups(account *Account, peerID 
string) []*nbdns.NameServerGroup { - groupList := account.GetPeerGroups(peerID) - - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -// peerIsNameserver returns true if the peer is a nameserver for a nsGroup -func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { - for _, ns := range nsGroup.NameServers { - if peer.IP.Equal(ns.IP.AsSlice()) { - return true - } - } - return false -} - func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { for _, peer := range account.Peers { label, err := GetPeerHostLabel(peer.Name, peerLabels) @@ -514,6 +269,8 @@ func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdn domainSuffix := "." + dnsDomain + ipv6AllowedPeers := a.peerIPv6AllowedSet() + var sb strings.Builder for _, peer := range a.Peers { if peer.DNSLabel == "" { @@ -525,13 +282,31 @@ func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdn sb.WriteString(peer.DNSLabel) sb.WriteString(domainSuffix) + fqdn := sb.String() customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), + Name: fqdn, Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: defaultTTL, RData: peer.IP.String(), }) + // Only advertise AAAA for peers that have a valid IPv6, whose client supports it, + // and that belong to an IPv6-enabled group. Old clients don't configure v6 on their + // WireGuard interface, so resolving their AAAA causes connections to hang. 
+ // Capability changes (client upgrade/downgrade, --disable-ipv6 toggle) propagate + // to other peers via SyncPeer/LoginPeer regardless of version change, so AAAA + // records refresh when a peer first reports the IPv6 overlay capability. + _, peerAllowed := ipv6AllowedPeers[peer.ID] + hasIPv6 := peer.IPv6.IsValid() && peer.SupportsIPv6() && peerAllowed + if hasIPv6 { + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: fqdn, + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IPv6.String(), + }) + } sb.Reset() for _, extraLabel := range peer.ExtraDNSLabels { @@ -539,13 +314,23 @@ func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdn sb.WriteString(extraLabel) sb.WriteString(domainSuffix) + extraFqdn := sb.String() customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), + Name: extraFqdn, Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: defaultTTL, RData: peer.IP.String(), }) + if hasIPv6 { + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: extraFqdn, + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IPv6.String(), + }) + } sb.Reset() } @@ -800,19 +585,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string { return grps } -func (a *Account) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := a.GetPeerGroups(peerID) - enabled := true - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - func (a *Account) GetPeerGroups(peerID string) LookupMap { groupList := make(LookupMap) for groupID, group := range a.Groups { @@ -826,8 +598,43 @@ func (a *Account) GetPeerGroups(peerID string) LookupMap { return groupList } -func (a *Account) GetTakenIPs() []net.IP { - var takenIps []net.IP +// PeerIPv6Allowed reports whether the given peer is in any of 
the account's IPv6 enabled groups. +// Returns false if IPv6 is disabled or no groups are configured. +func (a *Account) PeerIPv6Allowed(peerID string) bool { + if len(a.Settings.IPv6EnabledGroups) == 0 { + return false + } + + for _, groupID := range a.Settings.IPv6EnabledGroups { + group, ok := a.Groups[groupID] + if !ok { + continue + } + if slices.Contains(group.Peers, peerID) { + return true + } + } + return false +} + +// peerIPv6AllowedSet returns a set of peer IDs that belong to any IPv6-enabled group. +func (a *Account) peerIPv6AllowedSet() map[string]struct{} { + result := make(map[string]struct{}) + for _, groupID := range a.Settings.IPv6EnabledGroups { + group, ok := a.Groups[groupID] + if !ok { + continue + } + for _, peerID := range group.Peers { + result[peerID] = struct{}{} + } + } + return result +} + +// GetTakenIPs returns all peer IP addresses currently allocated in the account. +func (a *Account) GetTakenIPs() []netip.Addr { + takenIps := make([]netip.Addr, 0, len(a.Peers)) for _, existingPeer := range a.Peers { takenIps = append(takenIps, existingPeer.IP) } @@ -941,8 +748,6 @@ func (a *Account) Copy() *Account { NetworkResources: networkResources, Services: services, Onboarding: a.Onboarding, - NetworkMapCache: a.NetworkMapCache, - nmapInitOnce: a.nmapInitOnce, Domains: domains, } } @@ -1186,10 +991,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { rules = append(rules, &fr) - continue + } else { + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) } - rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) 
+ rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{ + direction: direction, + dirStr: strconv.Itoa(direction), + protocolStr: string(protocol), + actionStr: string(rule.Action), + portsJoined: strings.Join(rule.Ports, ","), + }) } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -1304,32 +1116,7 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks { return nil } -// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. -func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) - - enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) - for _, route := range enabledRoutes { - // If no access control groups are specified, accept all traffic. - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := a.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup}) - rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) 
- } - } - - return routesFirewallRules -} - -func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}, includeIPv6 bool) []*RouteFirewallRule { var fwRules []*RouteFirewallRule for _, policy := range policies { if !policy.Enabled { @@ -1342,7 +1129,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli } rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6) fwRules = append(fwRules, rules...) 
} } @@ -1387,50 +1174,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID return distributionGroupPeers } -func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func getDefaultPermit(route *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule - - sources := []string{"0.0.0.0/0"} - if route.Network.Addr().Is6() { - sources = []string{"::/0"} - } - rule := RouteFirewallRule{ - SourceRanges: sources, - Action: string(PolicyTrafficActionAccept), - Destination: route.Network.String(), - Protocol: string(PolicyRuleProtocolALL), - Domains: route.Domains, - IsDynamic: route.IsDynamic(), - RouteID: route.ID, - } - - rules = append(rules, &rule) - - // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally - if route.IsDynamic() { - ruleV6 := rule - ruleV6.SourceRanges = []string{"::/0"} - rules = append(rules, &ruleV6) - } - - return rules -} - // GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups // and returns a list of policies that have rules with destinations matching the specified groups. 
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { @@ -1468,7 +1211,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())] distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) - rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) + rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers, peer.SupportsIPv6() && peer.IPv6.IsValid()) for _, rule := range rules { if len(rule.SourceRanges) > 0 { routesFirewallRules = append(routesFirewallRules, rule) @@ -1508,65 +1251,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { return resourcePolicies } -// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. -func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { - var isRoutingPeer bool - var routes []*route.Route - allSourcePeers := make(map[string]struct{}, len(a.Peers)) - - for _, resource := range a.NetworkResources { - if !resource.Enabled { - continue - } - - var addSourcePeers bool - - networkRoutingPeers, exists := routers[resource.NetworkID] - if exists { - if router, ok := networkRoutingPeers[peerID]; ok { - isRoutingPeer, addSourcePeers = true, true - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...) 
- } - } - - addedResourceRoute := false - for _, policy := range resourcePolicies[resource.ID] { - var peers []string - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peers = []string{policy.Rules[0].SourceResource.ID} - } else { - peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) - } - if addSourcePeers { - for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { - allSourcePeers[pID] = struct{}{} - } - } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - // add routes for the resource if the peer is in the distribution group - for peerId, router := range networkRoutingPeers { - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) - } - addedResourceRoute = true - } - if addedResourceRoute { - break - } - } - } - - return isRoutingPeer, routes, allSourcePeers -} - -func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { - var dest []string - for _, peerID := range inputPeers { - if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { - dest = append(dest, peerID) - } - } - return dest -} - func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity for _, groupID := range groups { @@ -1658,22 +1342,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string { return result } -// getNetworkResourcesRoutes convert the network resources list to routes list. 
-func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route { - resourceAppliedPolicies := resourcePolicies[resource.ID] - - var routes []*route.Route - // distribute the resource routes only if there is policy applied to it - if len(resourceAppliedPolicies) > 0 { - peer := a.GetPeer(peerId) - if peer != nil { - routes = append(routes, resource.ToRoute(peer, router)) - } - } - - return routes -} - func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { routers := make(map[string]map[string]*routerTypes.NetworkRouter) @@ -1998,24 +1666,32 @@ func peerSupportedFirewallFeatures(peerVer string) supportedFeatures { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. +// AAAA records are excluded when the requesting peer lacks IPv6 capability. func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) - peerIPs := make(map[string]struct{}) + peerIPs := make(map[netip.Addr]struct{}, len(peersToConnect)+len(expiredPeers)+2) + includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid() - // Add peer's own IP to include its own DNS records - peerIPs[peer.IP.String()] = struct{}{} - - for _, peerToConnect := range peersToConnect { - peerIPs[peerToConnect.IP.String()] = struct{}{} + addPeerIPs := func(p *nbpeer.Peer) { + peerIPs[p.IP] = struct{}{} + if includeIPv6 && p.IPv6.IsValid() { + peerIPs[p.IPv6] = struct{}{} + } } - for _, expiredPeer := range expiredPeers { - peerIPs[expiredPeer.IP.String()] = struct{}{} + addPeerIPs(peer) + for _, p := range peersToConnect { + addPeerIPs(p) + } + for _, p := range expiredPeers { + addPeerIPs(p) } for _, record := range customZone.Records { - if _, exists := peerIPs[record.RData]; 
exists { - filteredRecords = append(filteredRecords, record) + if addr, err := netip.ParseAddr(record.RData); err == nil { + if _, exists := peerIPs[addr.Unmap()]; exists { + filteredRecords = append(filteredRecords, record) + } } } diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index bd4244546..2b4f7e051 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -115,7 +115,7 @@ func (a *Account) GetPeerNetworkMapComponents( components.Groups = relevantGroups components.Policies = relevantPolicies components.Routes = relevantRoutes - components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers) + components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers, peer.SupportsIPv6() && peer.IPv6.IsValid()) peerGroups := a.GetPeerGroups(peerID) components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups) @@ -539,15 +539,22 @@ func filterPostureFailedPeers(postureFailedPeers *map[string]map[string]struct{} } } -func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord { +func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer, includeIPv6 bool) []nbdns.SimpleRecord { if len(records) == 0 || len(peers) == 0 { return nil } - peerIPs := make(map[string]struct{}, len(peers)) + // Include both v4 and v6 addresses so AAAA records (whose RData is an IPv6 + // address) are not filtered out when peers have IPv6 assigned. When the + // requesting peer doesn't have IPv6, omit v6 IPs so AAAA records get dropped. 
+ peerIPs := make(map[string]struct{}, len(peers)*2) for _, peer := range peers { - if peer != nil { - peerIPs[peer.IP.String()] = struct{}{} + if peer == nil { + continue + } + peerIPs[peer.IP.String()] = struct{}{} + if includeIPv6 && peer.IPv6.IsValid() { + peerIPs[peer.IPv6.String()] = struct{}{} } } diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 00ba29b7f..a1a616882 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -3,9 +3,7 @@ package types import ( "context" "fmt" - "net" "net/netip" - "slices" "testing" "github.com/miekg/dns" @@ -19,7 +17,6 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" ) @@ -451,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { require.Len(t, result, 0) } -const ( - accID = "accountID" - network1ID = "network1ID" - group1ID = "group1" - accNetResourcePeer1ID = "peer1" - accNetResourcePeer2ID = "peer2" - accNetResourceRouter1ID = "router1" - accNetResource1ID = "resource1ID" - accNetResourceRestrictPostureCheckID = "restrictPostureCheck" - accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck" - accNetResourceLockedPostureCheckID = "lockedPostureCheck" - accNetResourceLinuxPostureCheckID = "linuxPostureCheck" -) - -var ( - accNetResourcePeer1IP = net.IP{192, 168, 1, 1} - accNetResourcePeer2IP = net.IP{192, 168, 1, 2} - accNetResourceRouter1IP = net.IP{192, 168, 1, 3} - accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}} -) - -func getBasicAccountsWithResource() *Account { - return &Account{ - Id: accID, - Peers: map[string]*nbpeer.Peer{ - accNetResourcePeer1ID: { - ID: 
accNetResourcePeer1ID, - AccountID: accID, - Key: "peer1Key", - IP: accNetResourcePeer1IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - WtVersion: "0.35.1", - KernelVersion: "4.4.0", - }, - }, - accNetResourcePeer2ID: { - ID: accNetResourcePeer2ID, - AccountID: accID, - Key: "peer2Key", - IP: accNetResourcePeer2IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "windows", - WtVersion: "0.34.1", - KernelVersion: "4.4.0", - }, - }, - accNetResourceRouter1ID: { - ID: accNetResourceRouter1ID, - AccountID: accID, - Key: "router1Key", - IP: accNetResourceRouter1IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - WtVersion: "0.35.1", - KernelVersion: "4.4.0", - }, - }, - }, - Groups: map[string]*Group{ - group1ID: { - ID: group1ID, - Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID}, - }, - }, - Networks: []*networkTypes.Network{ - { - ID: network1ID, - AccountID: accID, - Name: "network1", - }, - }, - NetworkRouters: []*routerTypes.NetworkRouter{ - { - ID: accNetResourceRouter1ID, - NetworkID: network1ID, - AccountID: accID, - Peer: accNetResourceRouter1ID, - PeerGroups: []string{}, - Masquerade: false, - Metric: 100, - Enabled: true, - }, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - { - ID: accNetResource1ID, - AccountID: accID, - NetworkID: network1ID, - Address: "10.10.10.0/24", - Prefix: netip.MustParsePrefix("10.10.10.0/24"), - Type: resourceTypes.NetworkResourceType("subnet"), - Enabled: true, - }, - }, - Policies: []*Policy{ - { - ID: "policy1ID", - AccountID: accID, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "rule1ID", - Enabled: true, - Sources: []string{group1ID}, - DestinationResource: Resource{ - ID: accNetResource1ID, - Type: "Host", - }, - Protocol: PolicyRuleProtocolTCP, - Ports: []string{"80"}, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: nil, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: accNetResourceRestrictPostureCheckID, - Name: accNetResourceRestrictPostureCheckID, - Checks: 
posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.35.0", - }, - }, - }, - { - ID: accNetResourceRelaxedPostureCheckID, - Name: accNetResourceRelaxedPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.0.1", - }, - }, - }, - { - ID: accNetResourceLockedPostureCheckID, - Name: accNetResourceLockedPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "7.7.7", - }, - }, - }, - { - ID: accNetResourceLinuxPostureCheckID, - Name: accNetResourceLinuxPostureCheckID, - Checks: posture.ChecksDefinition{ - OSVersionCheck: &posture.OSVersionCheck{ - Linux: &posture.MinKernelVersionCheck{ - MinKernelVersion: "0.0.0"}, - }, - }, - }, - }, - } -} - -func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // all peers should match the policy - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, 
account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should allow peer1 to match the policy - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") 
- - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should not match any peer - policy := account.Policies[0] 
- policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 0, "expected rules count don't match") -} - -func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should allow peer1 to match the policy - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} - - // should 
allow peer1 and peer2 to match the policy - newPolicy := &Policy{ - ID: "policy2ID", - AccountID: accID, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "policy2ID", - Enabled: true, - Sources: []string{group1ID}, - DestinationResource: Resource{ - ID: accNetResource1ID, - Type: "Host", - }, - Protocol: PolicyRuleProtocolTCP, - Ports: []string{"22"}, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID}, - } - - account.Policies = append(account.Policies, newPolicy) - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - assert.NotNil(t, 
sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 2, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } - - assert.Equal(t, uint16(22), rules[1].Port, "should have port 22") - assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp") - if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String()) - } - if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // two posture checks should match only the peers that match both checks - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), 
account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 
%s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { - account := getBasicAccountsWithResource() - - account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}} - account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ - ID: "router2Id", - NetworkID: network1ID, - AccountID: accID, - Peer: "router2Id", - }) - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") -} - func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { tests := []struct { name string @@ -1320,7 +921,11 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, peersToConnect: []*nbpeer.Peer{}, expiredPeers: []*nbpeer.Peer{}, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), + }, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, @@ -1347,14 +952,19 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { var peers []*nbpeer.Peer for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { peers = append(peers, &nbpeer.Peer{ - ID: fmt.Sprintf("peer%d", i), - IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + ID: fmt.Sprintf("peer%d", i), + IP: netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + IPv6: netip.MustParseAddr(fmt.Sprintf("fd00::%d", i)), }) } return peers }(), 
expiredPeers: []*nbpeer.Peer{}, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), + }, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -1385,11 +995,27 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{ - {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, - {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + { + ID: "peer1", + IP: netip.MustParseAddr("10.0.0.1"), + IPv6: netip.MustParseAddr("fd00::a00:1"), + DNSLabel: "peer1", + ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}, + }, + { + ID: "peer2", + IP: netip.MustParseAddr("10.0.0.2"), + IPv6: netip.MustParseAddr("fd00::a00:2"), + DNSLabel: "peer2", + ExtraDNSLabels: []string{"peer2-service"}, + }, }, expiredPeers: []*nbpeer.Peer{}, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), + }, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -1411,12 +1037,24 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{ - {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + { + ID: "peer1", + IP: netip.MustParseAddr("10.0.0.1"), + IPv6: netip.MustParseAddr("fd00::a00:1"), + }, }, expiredPeers: []*nbpeer.Peer{ - {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + { + ID: "expired-peer", + IP: netip.MustParseAddr("10.0.0.99"), + IPv6: 
netip.MustParseAddr("fd00::a00:63"), + }, + }, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index 19222a607..b76a94290 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -48,16 +48,26 @@ func (r *FirewallRule) Equal(other *FirewallRule) bool { } // generateRouteFirewallRules generates a list of firewall rules for a given route. -func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { +// For static routes, source ranges match the destination family (v4 or v6). +// For dynamic routes (domain-based), separate v4 and v6 rules are generated +// so the routing peer's forwarding chain allows both address families. 
+func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule { rulesExists := make(map[string]struct{}) rules := make([]*RouteFirewallRule, 0) - sourceRanges := make([]string, 0, len(groupPeers)) - for _, peer := range groupPeers { - if peer == nil { - continue - } - sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + v4Sources, v6Sources := splitPeerSourcesByFamily(groupPeers) + + isV6Route := route.Network.Addr().Is6() + + // Skip v6 destination routes entirely for peers without IPv6 support + if isV6Route && !includeIPv6 { + return rules + } + + // Pick sources matching the destination family + sourceRanges := v4Sources + if isV6Route { + sourceRanges = v6Sources } baseRule := RouteFirewallRule{ @@ -71,18 +81,47 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule IsDynamic: route.IsDynamic(), } - // generate rule for port range if len(rule.Ports) == 0 { rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) } else { rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) } - // TODO: generate IPv6 rules for dynamic routes + // Generate v6 counterpart for dynamic routes and 0.0.0.0/0 exit node routes. + isDefaultV4 := !isV6Route && route.Network.Bits() == 0 + if includeIPv6 && (route.IsDynamic() || isDefaultV4) && len(v6Sources) > 0 { + v6Rule := baseRule + v6Rule.SourceRanges = v6Sources + if isDefaultV4 { + v6Rule.Destination = "::/0" + v6Rule.RouteID = route.ID + "-v6-default" + } + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(v6Rule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, v6Rule, rule, rulesExists)...) + } + } return rules } +// splitPeerSourcesByFamily separates peer IPs into v4 (/32) and v6 (/128) source ranges. 
+func splitPeerSourcesByFamily(groupPeers []*nbpeer.Peer) (v4, v6 []string) { + v4 = make([]string, 0, len(groupPeers)) + v6 = make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + v4 = append(v4, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + if peer.IPv6.IsValid() { + v6 = append(v6, fmt.Sprintf(AllowedIPsV6Format, peer.IPv6)) + } + } + return +} + // generateRulesForPeer generates rules for a given peer based on ports and port ranges. func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { rules := make([]*RouteFirewallRule, 0) diff --git a/management/server/types/firewall_rule_test.go b/management/server/types/firewall_rule_test.go new file mode 100644 index 000000000..8d97a46bc --- /dev/null +++ b/management/server/types/firewall_rule_test.go @@ -0,0 +1,197 @@ +package types + +import ( + "context" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +func TestSplitPeerSourcesByFamily(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + { + IP: netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00::3"), + }, + nil, + } + + v4, v6 := splitPeerSourcesByFamily(peers) + + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32", "100.64.0.3/32"}, v4) + assert.Equal(t, []string{"fd00::1/128", "fd00::3/128"}, v6) +} + +func TestGenerateRouteFirewallRules_V4Route(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + } + + r := &route.Route{ + ID: 
"route1", + Network: netip.MustParsePrefix("10.0.0.0/24"), + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 1) + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges, "v4 route should only have v4 sources") + assert.Equal(t, "10.0.0.0/24", rules[0].Destination) +} + +func TestGenerateRouteFirewallRules_V6Route(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + } + + r := &route.Route{ + ID: "route1", + Network: netip.MustParsePrefix("2001:db8::/32"), + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 1) + assert.Equal(t, []string{"fd00::1/128"}, rules[0].SourceRanges, "v6 route should only have v6 sources") +} + +func TestGenerateRouteFirewallRules_DynamicRoute_DualStack(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + } + + r := &route.Route{ + ID: "route1", + NetworkType: route.DomainNetwork, + Domains: domain.List{"example.com"}, + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 2, "dynamic route should produce both v4 and v6 rules") + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, 
rules[0].SourceRanges) + assert.Equal(t, []string{"fd00::1/128"}, rules[1].SourceRanges) + assert.Equal(t, rules[0].Domains, rules[1].Domains) + assert.True(t, rules[0].IsDynamic) + assert.True(t, rules[1].IsDynamic) +} + +func TestGenerateRouteFirewallRules_DynamicRoute_NoV6Peers(t *testing.T) { + peers := []*nbpeer.Peer{ + {IP: netip.MustParseAddr("100.64.0.1")}, + {IP: netip.MustParseAddr("100.64.0.2")}, + } + + r := &route.Route{ + ID: "route1", + NetworkType: route.DomainNetwork, + Domains: domain.List{"example.com"}, + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 1, "no v6 peers means only v4 rule") + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges) +} + +func TestGenerateRouteFirewallRules_IncludeIPv6False(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), + }, + } + + t.Run("v6 route excluded", func(t *testing.T) { + r := &route.Route{ + ID: "route1", + Network: netip.MustParsePrefix("2001:db8::/32"), + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false) + assert.Empty(t, rules, "v6 route should produce no rules when includeIPv6 is false") + }) + + t.Run("dynamic route only v4", func(t *testing.T) { + r := &route.Route{ + ID: "route1", + NetworkType: route.DomainNetwork, + Domains: domain.List{"example.com"}, + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: 
PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false) + require.Len(t, rules, 1, "dynamic route with includeIPv6=false should produce only v4 rule") + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges) + }) +} diff --git a/management/server/types/holder.go b/management/server/types/holder.go deleted file mode 100644 index de8ac8110..000000000 --- a/management/server/types/holder.go +++ /dev/null @@ -1,47 +0,0 @@ -package types - -import ( - "context" - "sync" -) - -type Holder struct { - mu sync.RWMutex - accounts map[string]*Account -} - -func NewHolder() *Holder { - return &Holder{ - accounts: make(map[string]*Account), - } -} - -func (h *Holder) GetAccount(id string) *Account { - h.mu.RLock() - defer h.mu.RUnlock() - return h.accounts[id] -} - -func (h *Holder) AddAccount(account *Account) { - h.mu.Lock() - defer h.mu.Unlock() - a := h.accounts[account.Id] - if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() { - return - } - h.accounts[account.Id] = account -} - -func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { - h.mu.Lock() - defer h.mu.Unlock() - if acc, ok := h.accounts[id]; ok { - return acc, nil - } - account, err := accGetter(ctx, id) - if err != nil { - return nil, err - } - h.accounts[id] = account - return account, nil -} diff --git a/management/server/types/identity_provider.go b/management/server/types/identity_provider.go index c4498e4d4..0c1f9509c 100644 --- a/management/server/types/identity_provider.go +++ b/management/server/types/identity_provider.go @@ -39,6 +39,8 @@ const ( IdentityProviderTypeAuthentik IdentityProviderType = "authentik" // IdentityProviderTypeKeycloak is the Keycloak identity provider IdentityProviderTypeKeycloak IdentityProviderType = "keycloak" + // IdentityProviderTypeADFS is the Microsoft AD 
FS identity provider + IdentityProviderTypeADFS IdentityProviderType = "adfs" ) // IdentityProvider represents an identity provider configuration @@ -112,7 +114,8 @@ func (t IdentityProviderType) IsValid() bool { switch t { case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra, IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID, - IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak: + IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak, + IdentityProviderTypeADFS: return true } return false diff --git a/management/server/types/ipv6_endtoend_test.go b/management/server/types/ipv6_endtoend_test.go new file mode 100644 index 000000000..ddd1f649f --- /dev/null +++ b/management/server/types/ipv6_endtoend_test.go @@ -0,0 +1,156 @@ +package types_test + +import ( + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestNetworkMapComponents_IPv6EndToEnd(t *testing.T) { + account := createComponentTestAccount() + + // Make all peers IPv6-capable and assign IPv6 addrs. + v6Caps := []int32{nbpeer.PeerCapabilityIPv6Overlay, nbpeer.PeerCapabilitySourcePrefixes} + account.Peers["peer-src-1"].Meta.Capabilities = v6Caps + account.Peers["peer-src-1"].IPv6 = netip.MustParseAddr("fd00::1") + account.Peers["peer-src-2"].Meta.Capabilities = v6Caps + account.Peers["peer-src-2"].IPv6 = netip.MustParseAddr("fd00::2") + account.Peers["peer-dst-1"].Meta.Capabilities = v6Caps + account.Peers["peer-dst-1"].IPv6 = netip.MustParseAddr("fd00::3") + + // Mark group-src and group-dst as IPv6-enabled. 
+ account.Settings.IPv6EnabledGroups = []string{"group-src", "group-dst"} + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + require.NotNil(t, nm) + + t.Run("v6 AAAA records emitted", func(t *testing.T) { + require.NotEmpty(t, nm.DNSConfig.CustomZones, "expected at least one custom zone") + var hasAAAA bool + var hasA bool + for _, z := range nm.DNSConfig.CustomZones { + for _, r := range z.Records { + if r.Type == int(dns.TypeAAAA) { + hasAAAA = true + } + if r.Type == int(dns.TypeA) { + hasA = true + } + } + } + assert.True(t, hasA, "expected A records") + assert.True(t, hasAAAA, "expected AAAA records for IPv6-enabled peers") + }) + + t.Run("v6 AllowedIPs would be advertised", func(t *testing.T) { + // nm.Peers contains *nbpeer.Peer; IPv6 should be set on those peers + var foundV6 bool + for _, p := range nm.Peers { + if p.IPv6.IsValid() { + foundV6 = true + } + } + assert.True(t, foundV6, "remote peers should have IPv6 set so AllowedIPs gets v6") + }) + + t.Run("v6 firewall rules emitted", func(t *testing.T) { + require.NotEmpty(t, nm.FirewallRules, "expected firewall rules") + var hasV4 bool + var hasV6 bool + for _, r := range nm.FirewallRules { + addr, err := netip.ParseAddr(r.PeerIP) + if err != nil { + continue + } + if addr.Is4() { + hasV4 = true + } + if addr.Is6() { + hasV6 = true + } + } + assert.True(t, hasV4, "expected at least one v4 firewall rule (peer IP)") + assert.True(t, hasV6, "expected at least one v6 firewall rule (peer IPv6)") + }) +} + +// TestNetworkMapComponents_RemotePeerWithoutCapability checks the asymmetric +// case where the target peer is IPv6-capable but a remote peer has an IPv6 +// address assigned in the DB without yet reporting the capability flag. +// In that case the remote peer's v6 still appears in AllowedIPs (gated on +// the target peer's capability) but its AAAA record does not (gated on the +// remote peer's own capability). 
+func TestNetworkMapComponents_RemotePeerWithoutCapability(t *testing.T) { + account := createComponentTestAccount() + + v6Caps := []int32{nbpeer.PeerCapabilityIPv6Overlay, nbpeer.PeerCapabilitySourcePrefixes} + // Target is fully capable. + account.Peers["peer-src-1"].Meta.Capabilities = v6Caps + account.Peers["peer-src-1"].IPv6 = netip.MustParseAddr("fd00::1") + // Remote peer has v6 assigned but no capability flag yet (e.g. old client). + account.Peers["peer-dst-1"].IPv6 = netip.MustParseAddr("fd00::3") + + account.Settings.IPv6EnabledGroups = []string{"group-src", "group-dst"} + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + require.NotNil(t, nm) + + t.Run("AllowedIPs include remote v6", func(t *testing.T) { + var dst *nbpeer.Peer + for _, p := range nm.Peers { + if p.ID == "peer-dst-1" { + dst = p + } + } + require.NotNil(t, dst) + assert.True(t, dst.IPv6.IsValid(), "remote peer's v6 should still be present so AllowedIPs gets v6/128 (gated on target peer cap)") + }) + + t.Run("no AAAA for non-capable remote peer", func(t *testing.T) { + for _, z := range nm.DNSConfig.CustomZones { + for _, r := range z.Records { + if r.Type == int(dns.TypeAAAA) && r.RData == "fd00::3" { + t.Errorf("AAAA record for non-capable remote peer should NOT be emitted, got %+v", r) + } + } + } + }) +} + +// TestNetworkMapComponents_IPv6Disabled_NoV6Output asserts that a peer that +// does not support IPv6 (e.g. older client without the capability flag) gets +// no v6 firewall rules and no AAAA records, even if other peers have IPv6. 
+func TestNetworkMapComponents_IPv6Disabled_NoV6Output(t *testing.T) { + account := createComponentTestAccount() + + v6Caps := []int32{nbpeer.PeerCapabilityIPv6Overlay} + account.Peers["peer-src-2"].Meta.Capabilities = v6Caps + account.Peers["peer-src-2"].IPv6 = netip.MustParseAddr("fd00::2") + account.Peers["peer-dst-1"].Meta.Capabilities = v6Caps + account.Peers["peer-dst-1"].IPv6 = netip.MustParseAddr("fd00::3") + // peer-src-1 (target) intentionally has no capability and no IPv6. + + account.Settings.IPv6EnabledGroups = []string{"group-src", "group-dst"} + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + require.NotNil(t, nm) + + t.Run("no v6 firewall rules", func(t *testing.T) { + for _, r := range nm.FirewallRules { + addr, err := netip.ParseAddr(r.PeerIP) + if err != nil { + continue + } + assert.False(t, addr.Is6(), "v6 firewall rules should not be emitted for non-IPv6 peer (got %s)", r.PeerIP) + } + }) +} diff --git a/management/server/types/ipv6_groups_test.go b/management/server/types/ipv6_groups_test.go new file mode 100644 index 000000000..5151e1b1f --- /dev/null +++ b/management/server/types/ipv6_groups_test.go @@ -0,0 +1,234 @@ +package types + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestPeerIPv6Allowed(t *testing.T) { + account := &Account{ + Groups: map[string]*Group{ + "group-all": {ID: "group-all", Name: "All", Peers: []string{"peer1", "peer2", "peer3"}}, + "group-devs": {ID: "group-devs", Name: "Devs", Peers: []string{"peer1", "peer2"}}, + "group-infra": {ID: "group-infra", Name: "Infra", Peers: []string{"peer2", "peer3"}}, + "group-empty": {ID: "group-empty", Name: "Empty", Peers: []string{}}, + }, + Settings: &Settings{}, + } + + tests := []struct { + name string + enabledGroups []string + peerID string + expected bool + }{ + { + name: "empty groups list disables IPv6 
for all", + enabledGroups: []string{}, + peerID: "peer1", + expected: false, + }, + { + name: "All group enables IPv6 for everyone", + enabledGroups: []string{"group-all"}, + peerID: "peer1", + expected: true, + }, + { + name: "peer in enabled group gets IPv6", + enabledGroups: []string{"group-devs"}, + peerID: "peer1", + expected: true, + }, + { + name: "peer not in any enabled group denied IPv6", + enabledGroups: []string{"group-devs"}, + peerID: "peer3", + expected: false, + }, + { + name: "peer in multiple groups, one enabled", + enabledGroups: []string{"group-infra"}, + peerID: "peer2", + expected: true, + }, + { + name: "peer in multiple groups, other one enabled", + enabledGroups: []string{"group-devs"}, + peerID: "peer2", + expected: true, + }, + { + name: "multiple enabled groups, peer in one", + enabledGroups: []string{"group-devs", "group-infra"}, + peerID: "peer1", + expected: true, + }, + { + name: "multiple enabled groups, peer in both", + enabledGroups: []string{"group-devs", "group-infra"}, + peerID: "peer2", + expected: true, + }, + { + name: "nonexistent group ID in enabled list", + enabledGroups: []string{"group-deleted"}, + peerID: "peer1", + expected: false, + }, + { + name: "empty group in enabled list", + enabledGroups: []string{"group-empty"}, + peerID: "peer1", + expected: false, + }, + { + name: "unknown peer ID", + enabledGroups: []string{"group-all"}, + peerID: "peer-unknown", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + account.Settings.IPv6EnabledGroups = tc.enabledGroups + result := account.PeerIPv6Allowed(tc.peerID) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestIPv6RecalculationOnGroupChange(t *testing.T) { + peerWithV6 := func(id string, v6 string) *nbpeer.Peer { + p := &nbpeer.Peer{ + ID: id, + IP: netip.MustParseAddr("100.64.0.1"), + } + if v6 != "" { + p.IPv6 = netip.MustParseAddr(v6) + } + return p + } + + t.Run("peer loses IPv6 when removed from enabled 
groups", func(t *testing.T) { + peer := peerWithV6("peer1", "fd00::1") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{"peer1"}}, + "group-b": {ID: "group-b", Peers: []string{}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer should be allowed before change") + + // Move peer out of enabled group + account.Groups["group-a"].Peers = []string{} + account.Groups["group-b"].Peers = []string{"peer1"} + + assert.False(t, account.PeerIPv6Allowed("peer1"), "peer should be denied after group change") + }) + + t.Run("peer gains IPv6 when added to enabled group", func(t *testing.T) { + peer := peerWithV6("peer1", "") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{}}, + "group-b": {ID: "group-b", Peers: []string{"peer1"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a"}, + }, + } + + assert.False(t, account.PeerIPv6Allowed("peer1"), "peer should be denied before change") + + // Add peer to enabled group + account.Groups["group-a"].Peers = []string{"peer1"} + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer should be allowed after joining enabled group") + }) + + t.Run("peer in two groups, one leaves enabled list", func(t *testing.T) { + peer := peerWithV6("peer1", "fd00::1") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{"peer1"}}, + "group-b": {ID: "group-b", Peers: []string{"peer1"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a", "group-b"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1")) + + // Remove group-a from enabled list, peer still in group-b + account.Settings.IPv6EnabledGroups = []string{"group-b"} + + 
assert.True(t, account.PeerIPv6Allowed("peer1"), "peer should still be allowed via group-b") + }) + + t.Run("peer in two groups, both leave enabled list", func(t *testing.T) { + peer := peerWithV6("peer1", "fd00::1") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{"peer1"}}, + "group-b": {ID: "group-b", Peers: []string{"peer1"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a", "group-b"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1")) + + // Clear all enabled groups + account.Settings.IPv6EnabledGroups = []string{} + + assert.False(t, account.PeerIPv6Allowed("peer1"), "peer should be denied when no groups enabled") + }) + + t.Run("enabling a group gives only its peers IPv6", func(t *testing.T) { + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peer1": peerWithV6("peer1", ""), + "peer2": peerWithV6("peer2", ""), + "peer3": peerWithV6("peer3", ""), + }, + Groups: map[string]*Group{ + "group-devs": {ID: "group-devs", Peers: []string{"peer1", "peer2"}}, + "group-infra": {ID: "group-infra", Peers: []string{"peer2", "peer3"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-devs"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer1 in devs") + assert.True(t, account.PeerIPv6Allowed("peer2"), "peer2 in devs") + assert.False(t, account.PeerIPv6Allowed("peer3"), "peer3 not in devs") + + // Add infra group + account.Settings.IPv6EnabledGroups = []string{"group-devs", "group-infra"} + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer1 still in devs") + assert.True(t, account.PeerIPv6Allowed("peer2"), "peer2 in both") + assert.True(t, account.PeerIPv6Allowed("peer3"), "peer3 now in infra") + }) +} diff --git a/management/server/types/network.go b/management/server/types/network.go index 0d13de10f..fe67bfd97 100644 --- a/management/server/types/network.go +++ 
b/management/server/types/network.go @@ -2,8 +2,11 @@ package types import ( "encoding/binary" + "fmt" "math/rand" "net" + "net/netip" + "slices" "sync" "time" @@ -27,6 +30,12 @@ const ( // AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.64.30.1/32) AllowedIPsFormat = "%s/32" + // AllowedIPsV6Format generates AllowedIPs format for v6 (e.g. fd12:3456:7890::1/128) + AllowedIPsV6Format = "%s/128" + + // IPv6SubnetSize is the prefix length of per-account IPv6 subnets. + // Each account gets a /64 from its unique /48 ULA prefix. + IPv6SubnetSize = 64 ) type NetworkMap struct { @@ -111,7 +120,9 @@ func ipToBytes(ip net.IP) []byte { type Network struct { Identifier string `json:"id"` Net net.IPNet `gorm:"serializer:json"` - Dns string + // NetV6 is the IPv6 ULA subnet for this account's overlay. Empty if not yet allocated. + NetV6 net.IPNet `gorm:"serializer:json"` + Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. Serial uint64 @@ -121,20 +132,45 @@ type Network struct { // NewNetwork creates a new Network initializing it with a Serial=0 // It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets) +// and a random /64 subnet from fd00:4e42::/32 for IPv6. func NewNetwork() *Network { - n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize) sub, _ := n.Subnet(SubnetSize) - s := rand.NewSource(time.Now().Unix()) + s := rand.NewSource(time.Now().UnixNano()) r := rand.New(s) intn := r.Intn(len(sub)) return &Network{ Identifier: xid.New().String(), Net: sub[intn].IPNet, + NetV6: AllocateIPv6Subnet(r), Dns: "", - Serial: 0} + Serial: 0, + } +} + +// AllocateIPv6Subnet generates a random RFC 4193 ULA /64 prefix. +// The format follows RFC 4193 section 3.1: fd + 40-bit Global ID + 16-bit Subnet ID. 
+// The Global ID and Subnet ID are randomized (simplified from the SHA-1 algorithm +// in section 3.2.2), giving 2^56 possible /64 subnets across all accounts. +func AllocateIPv6Subnet(r *rand.Rand) net.IPNet { + ip := make(net.IP, 16) + ip[0] = 0xfd + // Bytes 1-5: 40-bit random Global ID + ip[1] = byte(r.Intn(256)) + ip[2] = byte(r.Intn(256)) + ip[3] = byte(r.Intn(256)) + ip[4] = byte(r.Intn(256)) + ip[5] = byte(r.Intn(256)) + // Bytes 6-7: 16-bit random Subnet ID + ip[6] = byte(r.Intn(256)) + ip[7] = byte(r.Intn(256)) + + return net.IPNet{ + IP: ip, + Mask: net.CIDRMask(IPv6SubnetSize, 128), + } } // IncSerial increments Serial by 1 reflecting that the network state has been changed @@ -157,19 +193,19 @@ func (n *Network) Copy() *Network { return &Network{ Identifier: n.Identifier, Net: n.Net, + NetV6: n.NetV6, Dns: n.Dns, Serial: n.Serial, } } -// AllocatePeerIP pics an available IP from an net.IPNet. -// This method considers already taken IPs and reuses IPs if there are gaps in takenIps -// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 -func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { - baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) - - ones, bits := ipNet.Mask.Size() - hostBits := bits - ones +// AllocatePeerIP picks an available IP from a netip.Prefix. +// This method considers already taken IPs and reuses IPs if there are gaps in takenIps. +// E.g. if prefix=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3. 
+func AllocatePeerIP(prefix netip.Prefix, takenIps []netip.Addr) (netip.Addr, error) { + b := prefix.Masked().Addr().As4() + baseIP := binary.BigEndian.Uint32(b[:]) + hostBits := 32 - prefix.Bits() totalIPs := uint32(1 << hostBits) taken := make(map[uint32]struct{}, len(takenIps)+1) @@ -177,7 +213,8 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP for _, ip := range takenIps { - taken[ipToUint32(ip)] = struct{}{} + ab := ip.As4() + taken[binary.BigEndian.Uint32(ab[:])] = struct{}{} } rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -198,15 +235,14 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { } } - return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String()) + return netip.Addr{}, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", prefix.String()) } -func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) { - baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) - - ones, bits := ipNet.Mask.Size() - hostBits := bits - ones - +// AllocateRandomPeerIP picks a random available IP from a netip.Prefix. +func AllocateRandomPeerIP(prefix netip.Prefix) (netip.Addr, error) { + b := prefix.Masked().Addr().As4() + baseIP := binary.BigEndian.Uint32(b[:]) + hostBits := 32 - prefix.Bits() totalIPs := uint32(1 << hostBits) rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -216,18 +252,75 @@ func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) { return uint32ToIP(candidate), nil } -func ipToUint32(ip net.IP) uint32 { - ip = ip.To4() - if len(ip) < 4 { - return 0 +// AllocateRandomPeerIPv6 picks a random host address within the given IPv6 prefix. +// Only the host bits (after the prefix length) are randomized. 
+func AllocateRandomPeerIPv6(prefix netip.Prefix) (netip.Addr, error) { + ones := prefix.Bits() + if ones == 0 || ones > 126 || !prefix.Addr().Is6() { + return netip.Addr{}, fmt.Errorf("invalid IPv6 subnet: %s", prefix.String()) } - return binary.BigEndian.Uint32(ip) + + ip := prefix.Addr().As16() + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + + // Determine which byte the host bits start in + firstHostByte := ones / 8 + // If the prefix doesn't end on a byte boundary, handle the partial byte + partialBits := ones % 8 + + if partialBits > 0 { + // Keep the network bits in the partial byte, randomize the rest + hostMask := byte(0xff >> partialBits) + ip[firstHostByte] = (ip[firstHostByte] & ^hostMask) | (byte(rng.Intn(256)) & hostMask) + firstHostByte++ + } + + // Randomize remaining full host bytes + for i := firstHostByte; i < 16; i++ { + ip[i] = byte(rng.Intn(256)) + } + + // Avoid all-zeros and all-ones host parts by checking only host bits. + if isHostAllZeroOrOnes(ip[:], ones) { + ip = prefix.Masked().Addr().As16() + ip[15] |= 0x01 + } + + return netip.AddrFrom16(ip).Unmap(), nil } -func uint32ToIP(n uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, n) - return ip +// isHostAllZeroOrOnes checks whether all host bits (after prefixLen) are zero or all ones. 
+func isHostAllZeroOrOnes(ip []byte, prefixLen int) bool { + hostStart := prefixLen / 8 + partialBits := prefixLen % 8 + + hostSlice := slices.Clone(ip[hostStart:]) + if partialBits > 0 { + hostSlice[0] &= 0xff >> partialBits + } + + allZero := !slices.ContainsFunc(hostSlice, func(v byte) bool { return v != 0 }) + if allZero { + return true + } + + // Build the all-ones mask for host bits + onesMask := make([]byte, len(hostSlice)) + for i := range onesMask { + onesMask[i] = 0xff + } + if partialBits > 0 { + onesMask[0] = 0xff >> partialBits + } + + return slices.Equal(hostSlice, onesMask) +} + +func uint32ToIP(n uint32) netip.Addr { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + return netip.AddrFrom4(b) } // generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list diff --git a/management/server/types/network_test.go b/management/server/types/network_test.go index 4c1459ce5..d8a06dbbc 100644 --- a/management/server/types/network_test.go +++ b/management/server/types/network_test.go @@ -1,7 +1,9 @@ package types import ( + "encoding/binary" "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -17,10 +19,10 @@ func TestNewNetwork(t *testing.T) { } func TestAllocatePeerIP(t *testing.T) { - ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}} - var ips []net.IP + prefix := netip.MustParsePrefix("100.64.0.0/24") + var ips []netip.Addr for i := 0; i < 252; i++ { - ip, err := AllocatePeerIP(ipNet, ips) + ip, err := AllocatePeerIP(prefix, ips) if err != nil { t.Fatal(err) } @@ -41,19 +43,19 @@ func TestAllocatePeerIP(t *testing.T) { func TestAllocatePeerIPSmallSubnet(t *testing.T) { // Test /27 network (10.0.0.0/27) - should only have 30 usable IPs (10.0.0.1 to 10.0.0.30) - ipNet := net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.IPMask{255, 255, 255, 224}} - var ips []net.IP + prefix := netip.MustParsePrefix("10.0.0.0/27") + var ips []netip.Addr // 
Allocate all available IPs in the /27 network for i := 0; i < 30; i++ { - ip, err := AllocatePeerIP(ipNet, ips) + ip, err := AllocatePeerIP(prefix, ips) if err != nil { t.Fatal(err) } // Verify IP is within the correct range - if !ipNet.Contains(ip) { - t.Errorf("allocated IP %s is not within network %s", ip.String(), ipNet.String()) + if !prefix.Contains(ip) { + t.Errorf("allocated IP %s is not within network %s", ip.String(), prefix.String()) } ips = append(ips, ip) @@ -72,7 +74,7 @@ func TestAllocatePeerIPSmallSubnet(t *testing.T) { } // Try to allocate one more IP - should fail as network is full - _, err := AllocatePeerIP(ipNet, ips) + _, err := AllocatePeerIP(prefix, ips) if err == nil { t.Error("expected error when network is full, but got none") } @@ -95,10 +97,11 @@ func TestAllocatePeerIPVariousCIDRs(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - _, ipNet, err := net.ParseCIDR(tc.cidr) + prefix, err := netip.ParsePrefix(tc.cidr) require.NoError(t, err) + prefix = prefix.Masked() - var ips []net.IP + var ips []netip.Addr // For larger networks, test only a subset to avoid long test runs testCount := tc.expectedUsable @@ -108,21 +111,21 @@ func TestAllocatePeerIPVariousCIDRs(t *testing.T) { // Allocate IPs and verify they're within the correct range for i := 0; i < testCount; i++ { - ip, err := AllocatePeerIP(*ipNet, ips) + ip, err := AllocatePeerIP(prefix, ips) require.NoError(t, err, "failed to allocate IP %d", i) // Verify IP is within the correct range - assert.True(t, ipNet.Contains(ip), "allocated IP %s is not within network %s", ip.String(), ipNet.String()) + assert.True(t, prefix.Contains(ip), "allocated IP %s is not within network %s", ip.String(), prefix.String()) // Verify IP is not network or broadcast address - networkIP := ipNet.IP.Mask(ipNet.Mask) - ones, bits := ipNet.Mask.Size() - hostBits := bits - ones - broadcastInt := uint32(ipToUint32(networkIP)) + (1 << hostBits) - 1 - broadcastIP := 
uint32ToIP(broadcastInt) + networkAddr := prefix.Masked().Addr() + hostBits := 32 - prefix.Bits() + b := networkAddr.As4() + baseIP := binary.BigEndian.Uint32(b[:]) + broadcastIP := uint32ToIP(baseIP + (1 << hostBits) - 1) - assert.False(t, ip.Equal(networkIP), "allocated network address %s", ip.String()) - assert.False(t, ip.Equal(broadcastIP), "allocated broadcast address %s", ip.String()) + assert.NotEqual(t, networkAddr, ip, "allocated network address %s", ip.String()) + assert.NotEqual(t, broadcastIP, ip, "allocated broadcast address %s", ip.String()) ips = append(ips, ip) } @@ -151,3 +154,111 @@ func TestGenerateIPs(t *testing.T) { t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String()) } } + +func TestNewNetworkHasIPv6(t *testing.T) { + network := NewNetwork() + + assert.NotNil(t, network.NetV6.IP, "v6 subnet should be allocated") + assert.True(t, network.NetV6.IP.To4() == nil, "v6 subnet should be IPv6") + assert.Equal(t, byte(0xfd), network.NetV6.IP[0], "v6 subnet should be ULA (fd prefix)") + + ones, bits := network.NetV6.Mask.Size() + assert.Equal(t, 64, ones, "v6 subnet should be /64") + assert.Equal(t, 128, bits) +} + +func TestAllocateIPv6SubnetUniqueness(t *testing.T) { + seen := make(map[string]struct{}) + for i := 0; i < 100; i++ { + network := NewNetwork() + key := network.NetV6.IP.String() + _, duplicate := seen[key] + assert.False(t, duplicate, "duplicate v6 subnet: %s", key) + seen[key] = struct{}{} + } +} + +func TestAllocateRandomPeerIPv6(t *testing.T) { + prefix := netip.MustParsePrefix("fd12:3456:7890:abcd::/64") + + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + + assert.True(t, ip.Is6(), "should be IPv6") + assert.True(t, prefix.Contains(ip), "should be within subnet") + // First 8 bytes (network prefix) should match + b := ip.As16() + prefixBytes := prefix.Addr().As16() + assert.Equal(t, prefixBytes[:8], b[:8], "prefix should match") + // Interface ID should not be all zeros + allZero := 
true + for _, v := range b[8:] { + if v != 0 { + allZero = false + break + } + } + assert.False(t, allZero, "interface ID should not be all zeros") +} + +func TestAllocateRandomPeerIPv6_VariousPrefixes(t *testing.T) { + tests := []struct { + name string + cidr string + prefix int + }{ + {"standard /64", "fd00:1234:5678:abcd::/64", 64}, + {"small /112", "fd00:1234:5678:abcd::/112", 112}, + {"large /48", "fd00:1234::/48", 48}, + {"non-boundary /60", "fd00:1234:5670::/60", 60}, + {"non-boundary /52", "fd00:1230::/52", 52}, + {"minimum /120", "fd00:1234:5678:abcd::100/120", 120}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefix, err := netip.ParsePrefix(tt.cidr) + require.NoError(t, err) + prefix = prefix.Masked() + + assert.Equal(t, tt.prefix, prefix.Bits()) + + for i := 0; i < 50; i++ { + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + assert.True(t, prefix.Contains(ip), "IP %s should be within %s", ip, prefix) + } + }) + } +} + +func TestAllocateRandomPeerIPv6_PreservesNetworkBits(t *testing.T) { + // For a /112, bytes 0-13 should be preserved, only bytes 14-15 should vary + prefix := netip.MustParsePrefix("fd00:1234:5678:abcd:ef01:2345:6789:0/112") + + prefixBytes := prefix.Addr().As16() + for i := 0; i < 20; i++ { + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + // First 14 bytes (112 bits = 14 bytes) must match the network + b := ip.As16() + assert.Equal(t, prefixBytes[:14], b[:14], "network bytes should be preserved for /112") + } +} + +func TestAllocateRandomPeerIPv6_NonByteBoundary(t *testing.T) { + // For a /60, the first 7.5 bytes are network, so byte 7 is partial + prefix := netip.MustParsePrefix("fd00:1234:5678:abc0::/60") + + prefixBytes := prefix.Addr().As16() + for i := 0; i < 50; i++ { + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + b := ip.As16() + assert.True(t, prefix.Contains(ip), "IP %s should be within %s", ip, prefix) + // First 7 bytes must 
match exactly + assert.Equal(t, prefixBytes[:7], b[:7], "full network bytes should match for /60") + // Byte 7: top 4 bits (0xc = 1100) must be preserved + assert.Equal(t, prefixBytes[7]&0xf0, b[7]&0xf0, "partial byte network bits should be preserved for /60") + } +} diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go deleted file mode 100644 index 68c988a93..000000000 --- a/management/server/types/networkmap.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "context" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" -) - -func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { - if a.NetworkMapCache != nil { - return - } - a.nmapInitOnce.Do(func() { - a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) - }) -} - -func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { - a.initNetworkMapBuilder(validatedPeers) -} - -func (a *Account) GetPeerNetworkMapExp( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - validatedPeers map[string]struct{}, - metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - a.initNetworkMapBuilder(validatedPeers) - return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics) -} - -func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { - if a.NetworkMapCache == nil { - return nil - } - return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId) -} - -func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) { - if a.NetworkMapCache == nil { - return - } - a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...) 
-} - -func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { - if a.NetworkMapCache == nil { - return nil - } - return a.NetworkMapCache.OnPeerDeleted(a, peerId) -} - -func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { - if a.NetworkMapCache == nil { - return - } - a.NetworkMapCache.UpdatePeer(peer) -} - -func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { - a.initNetworkMapBuilder(validatedPeers) -} diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go deleted file mode 100644 index c5844cca0..000000000 --- a/management/server/types/networkmap_comparison_test.go +++ /dev/null @@ -1,592 +0,0 @@ -package types - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/netip" - "os" - "path/filepath" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/require" - - nbdns "github.com/netbirdio/netbird/dns" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/route" -) - -func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { - account := createTestAccount() - ctx := context.Background() - - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid == offlinePeerID { - continue - } - validatedPeersMap[pid] = struct{}{} - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - legacyNetworkMap := account.GetPeerNetworkMap( - ctx, - 
peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - if components == nil { - t.Fatal("GetPeerNetworkMapComponents returned nil") - } - - newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) - - if newNetworkMap == nil { - t.Fatal("CalculateNetworkMapFromComponents returned nil") - } - - compareNetworkMaps(t, legacyNetworkMap, newNetworkMap) -} - -func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { - account := createTestAccount() - ctx := context.Background() - - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid == offlinePeerID { - continue - } - validatedPeersMap[pid] = struct{}{} - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - legacyNetworkMap := account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") - - newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) - require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") - - normalizeAndSortNetworkMap(legacyNetworkMap) - normalizeAndSortNetworkMap(newNetworkMap) - - componentsJSON, err := json.MarshalIndent(components, "", " ") - require.NoError(t, err, "error marshaling components to JSON") - - legacyJSON, err := 
json.MarshalIndent(legacyNetworkMap, "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - goldenDir := filepath.Join("testdata", "comparison") - err = os.MkdirAll(goldenDir, 0755) - require.NoError(t, err) - - legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json") - err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644) - require.NoError(t, err, "error writing legacy golden file") - - newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json") - err = os.WriteFile(newGoldenPath, newJSON, 0644) - require.NoError(t, err, "error writing components golden file") - - componentsPath := filepath.Join(goldenDir, "components.json") - err = os.WriteFile(componentsPath, componentsJSON, 0644) - require.NoError(t, err, "error writing components golden file") - - require.JSONEq(t, string(legacyJSON), string(newJSON), - "NetworkMaps from legacy and components approaches do not match.\n"+ - "Legacy JSON saved to: %s\n"+ - "Components JSON saved to: %s", - legacyGoldenPath, newGoldenPath) - - t.Logf("✅ NetworkMaps are identical") - t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath) - t.Logf(" Components NetworkMap: %s", newGoldenPath) -} - -func normalizeAndSortNetworkMap(nm *NetworkMap) { - if nm == nil { - return - } - - sort.Slice(nm.Peers, func(i, j int) bool { - return nm.Peers[i].ID < nm.Peers[j].ID - }) - - sort.Slice(nm.OfflinePeers, func(i, j int) bool { - return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID - }) - - sort.Slice(nm.Routes, func(i, j int) bool { - return string(nm.Routes[i].ID) < string(nm.Routes[j].ID) - }) - - sort.Slice(nm.FirewallRules, func(i, j int) bool { - if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP { - return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP - } - if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction { - 
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction - } - if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol { - return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol - } - if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port { - return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port - } - return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID - }) - - for i := range nm.RoutesFirewallRules { - sort.Strings(nm.RoutesFirewallRules[i].SourceRanges) - } - - sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool { - if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination { - return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination - } - - minLen := len(nm.RoutesFirewallRules[i].SourceRanges) - if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen { - minLen = len(nm.RoutesFirewallRules[j].SourceRanges) - } - for k := 0; k < minLen; k++ { - if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] { - return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k] - } - } - if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) { - return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges) - } - - if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) { - return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID) - } - - if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID { - return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID - } - - if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port { - return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port - } - - return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol - }) - - if nm.DNSConfig.CustomZones != 
nil { - for i := range nm.DNSConfig.CustomZones { - sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool { - return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name - }) - } - } - - if len(nm.DNSConfig.NameServerGroups) != 0 { - sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool { - return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name - }) - } -} - -func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) { - t.Helper() - - if legacy.Network.Serial != current.Network.Serial { - t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial) - } - - if len(legacy.Peers) != len(current.Peers) { - t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers)) - } - - legacyPeerIDs := make(map[string]bool) - for _, p := range legacy.Peers { - legacyPeerIDs[p.ID] = true - } - - for _, p := range current.Peers { - if !legacyPeerIDs[p.ID] { - t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID) - } - } - - if len(legacy.OfflinePeers) != len(current.OfflinePeers) { - t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers)) - } - - if len(legacy.FirewallRules) != len(current.FirewallRules) { - t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules)) - } - - if len(legacy.Routes) != len(current.Routes) { - t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes)) - } - - if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) { - t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules)) - } - - if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable { - t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", 
legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable) - } -} - -const ( - numPeers = 100 - devGroupID = "group-dev" - opsGroupID = "group-ops" - allGroupID = "group-all" - routeID = route.ID("route-main") - routeHA1ID = route.ID("route-ha-1") - routeHA2ID = route.ID("route-ha-2") - policyIDDevOps = "policy-dev-ops" - policyIDAll = "policy-all" - policyIDPosture = "policy-posture" - policyIDDrop = "policy-drop" - postureCheckID = "posture-check-ver" - networkResourceID = "res-database" - networkID = "net-database" - networkRouterID = "router-database" - nameserverGroupID = "ns-group-main" - testingPeerID = "peer-60" - expiredPeerID = "peer-98" - offlinePeerID = "peer-99" - routingPeerID = "peer-95" - testAccountID = "account-comparison-test" -) - -func createTestAccount() *Account { - peers := make(map[string]*nbpeer.Peer) - devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} - - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, 64, 0, byte(i + 1)} - wtVersion := "0.25.0" - if i%2 == 0 { - wtVersion = "0.40.0" - } - - p := &nbpeer.Peer{ - ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, - } - - if peerID == expiredPeerID { - p.LoginExpirationEnabled = true - pastTimestamp := time.Now().Add(-2 * time.Hour) - p.LastLogin = &pastTimestamp - } - - peers[peerID] = p - allGroupPeers = append(allGroupPeers, peerID) - if i < numPeers/2 { - devGroupPeers = append(devGroupPeers, peerID) - } else { - opsGroupPeers = append(opsGroupPeers, peerID) - } - } - - groups := map[string]*Group{ - allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, - devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, - opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, - } - 
- policies := []*Policy{ - { - ID: policyIDAll, Name: "Default-Allow", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{allGroupID}, Destinations: []string{allGroupID}, - }}, - }, - { - ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolTCP, Bidirectional: false, - PortRanges: []RulePortRange{{Start: 8080, End: 8090}}, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop, - Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, - SourcePostureChecks: []string{postureCheckID}, - Rules: []*PolicyRule{{ - ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID}, - }}, - }, - } - - routes := map[route.ID]*route.Route{ - routeID: { - ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), - Peer: peers["peer-75"].Key, - PeerID: "peer-75", - Description: "Route to internal resource", Enabled: true, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - }, - routeHA1ID: { - ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-80"].Key, - PeerID: "peer-80", - Description: 
"HA Route 1", Enabled: true, Metric: 1000, - PeerGroups: []string{allGroupID}, - Groups: []string{allGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - routeHA2ID: { - ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-90"].Key, - PeerID: "peer-90", - Description: "HA Route 2", Enabled: true, Metric: 900, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - } - - account := &Account{ - Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, - Network: &Network{ - Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, - }, - DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, - NameServerGroups: map[string]*nbdns.NameServerGroup{ - nameserverGroupID: { - ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, - NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, - }, - }, - PostureChecks: []*posture.Checks{ - {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, - }}, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, - }, - Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, - NetworkRouters: []*routerTypes.NetworkRouter{ - {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, - }, - Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, - } - - for _, p := range account.Policies { - p.AccountID = account.Id - } - for _, r := range account.Routes { - r.AccountID = account.Id - } - - 
return account -} - -func BenchmarkLegacyNetworkMap(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - } -} - -func BenchmarkComponentsNetworkMap(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - _ = CalculateNetworkMapFromComponents(ctx, components) - } -} - -func BenchmarkComponentsCreation(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := 
account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - } -} - -func BenchmarkCalculationFromComponents(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = CalculateNetworkMapFromComponents(ctx, components) - } -} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index 23d84a994..3a7e20ec5 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -3,7 +3,6 @@ package types import ( "context" "maps" - "net" "net/netip" "slices" "strconv" @@ -19,8 +18,6 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" ) -const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED" - type NetworkMapComponents struct { PeerID string @@ -116,13 +113,17 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { peersToConnect, expiredPeers := c.filterPeersByLoginExpiration(aclPeers) - routesUpdate := c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups) - routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID) + 
includeIPv6 := false + if p := c.Peers[targetPeerID]; p != nil { + includeIPv6 = p.SupportsIPv6() && p.IPv6.IsValid() + } + routesUpdate := filterAndExpandRoutes(c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups), includeIPv6) + routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID, includeIPv6) isRouter, networkResourcesRoutes, sourcePeers := c.getNetworkResourcesRoutesToSync(targetPeerID) var networkResourcesFirewallRules []*RouteFirewallRule if isRouter { - networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes) + networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes, includeIPv6) } peersToConnectIncludingRouters := c.addNetworksRoutingPeers( @@ -158,7 +159,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { return &NetworkMap{ Peers: peersToConnectIncludingRouters, Network: c.Network.Copy(), - Routes: append(networkResourcesRoutes, routesUpdate...), + Routes: append(filterAndExpandRoutes(networkResourcesRoutes, includeIPv6), routesUpdate...), DNSConfig: dnsUpdate, OfflinePeers: expiredPeers, FirewallRules: firewallRules, @@ -298,7 +299,7 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peersExists[peer.ID] = struct{}{} } - peerIP := net.IP(peer.IP).String() + peerIP := peer.IP.String() fr := FirewallRule{ PolicyID: rule.ID, @@ -317,10 +318,17 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { rules = append(rules, &fr) - continue + } else { + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) } - rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) 
+ rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{ + direction: direction, + dirStr: dirStr, + protocolStr: protocolStr, + actionStr: actionStr, + portsJoined: portsJoined, + }) } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -456,6 +464,29 @@ func (c *NetworkMapComponents) peerIsNameserver(peerIPStr string, nsGroup *nbdns return false } +// filterAndExpandRoutes drops v6 routes for non-capable peers and duplicates +// the default v4 route (0.0.0.0/0) as ::/0 for v6-capable peers. +// TODO: the "-v6" suffix on IDs could collide with user-supplied route IDs. +func filterAndExpandRoutes(routes []*route.Route, includeIPv6 bool) []*route.Route { + filtered := make([]*route.Route, 0, len(routes)) + for _, r := range routes { + if !includeIPv6 && r.Network.Addr().Is6() { + continue + } + filtered = append(filtered, r) + + if includeIPv6 && r.Network.Bits() == 0 && r.Network.Addr().Is4() { + v6 := r.Copy() + v6.ID = r.ID + "-v6-default" + v6.NetID = r.NetID + "-v6" + v6.Network = netip.MustParsePrefix("::/0") + v6.NetworkType = route.IPv6Network + filtered = append(filtered, v6) + } + } + return filtered +} + func (c *NetworkMapComponents) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { routes, peerDisabledRoutes := c.getRoutingPeerRoutes(peerID) peerRoutesMembership := make(LookupMap) @@ -552,13 +583,13 @@ func (c *NetworkMapComponents) filterRoutesFromPeersOfSameHAGroup(routes []*rout return filteredRoutes } -func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule { +func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string, includeIPv6 bool) []*RouteFirewallRule { routesFirewallRules := make([]*RouteFirewallRule, 0) enabledRoutes, _ := c.getRoutingPeerRoutes(peerID) for _, r := range enabledRoutes { if len(r.AccessControlGroups) == 0 { - defaultPermit := 
c.getDefaultPermit(r) + defaultPermit := c.getDefaultPermit(r, includeIPv6) routesFirewallRules = append(routesFirewallRules, defaultPermit...) continue } @@ -567,7 +598,7 @@ func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, p for _, accessGroup := range r.AccessControlGroups { policies := c.getAllRoutePoliciesFromGroups([]string{accessGroup}) - rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers) + rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers, includeIPv6) routesFirewallRules = append(routesFirewallRules, rules...) } } @@ -575,8 +606,10 @@ func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, p return routesFirewallRules } -func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule +func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool) []*RouteFirewallRule { + if r.Network.Addr().Is6() && !includeIPv6 { + return nil + } sources := []string{"0.0.0.0/0"} if r.Network.Addr().Is6() { @@ -593,9 +626,9 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewall RouteID: r.ID, } - rules = append(rules, &rule) + rules := []*RouteFirewallRule{&rule} - if r.IsDynamic() { + if includeIPv6 && r.IsDynamic() { ruleV6 := rule ruleV6.SourceRanges = []string{"::/0"} rules = append(rules, &ruleV6) @@ -634,7 +667,7 @@ func (c *NetworkMapComponents) getAllRoutePoliciesFromGroups(accessControlGroups return routePolicies } -func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule { +func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}, includeIPv6 bool) []*RouteFirewallRule { var fwRules []*RouteFirewallRule for _, policy := 
range policies { if !policy.Enabled { @@ -647,7 +680,7 @@ func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID } rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers) - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6) fwRules = append(fwRules, rules...) } } @@ -712,33 +745,49 @@ func (c *NetworkMapComponents) getNetworkResourcesRoutesToSync(peerID string) (b } } - addedResourceRoute := false - for _, policy := range c.ResourcePoliciesMap[resource.ID] { - var peers []string - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peers = []string{policy.Rules[0].SourceResource.ID} - } else { - peers = c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups()) - } - if addSourcePeers { - for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) { - allSourcePeers[pID] = struct{}{} - } - } else if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { - for peerId, router := range networkRoutingPeers { - routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...) - } - addedResourceRoute = true - } - if addedResourceRoute { - break - } - } + newRoutes := c.processResourcePolicies(peerID, resource, networkRoutingPeers, addSourcePeers, allSourcePeers) + routes = append(routes, newRoutes...) 
} return isRoutingPeer, routes, allSourcePeers } +func (c *NetworkMapComponents) processResourcePolicies( + peerID string, + resource *resourceTypes.NetworkResource, + networkRoutingPeers map[string]*routerTypes.NetworkRouter, + addSourcePeers bool, + allSourcePeers map[string]struct{}, +) []*route.Route { + var routes []*route.Route + + for _, policy := range c.ResourcePoliciesMap[resource.ID] { + peers := c.getResourcePolicyPeers(policy) + if addSourcePeers { + for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } + continue + } + + if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...) + } + break + } + } + + return routes +} + +func (c *NetworkMapComponents) getResourcePolicyPeers(policy *Policy) []string { + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + return []string{policy.Rules[0].SourceResource.ID} + } + return c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups()) +} + func (c *NetworkMapComponents) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route { resourceAppliedPolicies := c.ResourcePoliciesMap[resource.ID] @@ -798,7 +847,7 @@ func (c *NetworkMapComponents) getPostureValidPeers(inputPeers []string, posture return dest } -func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule { +func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route, includeIPv6 bool) []*RouteFirewallRule { routesFirewallRules := make([]*RouteFirewallRule, 0) peerInfo := c.GetPeerInfo(peerID) @@ -815,7 +864,7 @@ func (c 
*NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.C resourcePolicies := c.ResourcePoliciesMap[resourceID] distributionPeers := c.getPoliciesSourcePeers(resourcePolicies) - rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers) + rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers, includeIPv6) for _, rule := range rules { if len(rule.SourceRanges) > 0 { routesFirewallRules = append(routesFirewallRules, rule) @@ -899,3 +948,36 @@ func (c *NetworkMapComponents) addNetworksRoutingPeers( return peersToConnect } + +type firewallRuleContext struct { + direction int + dirStr string + protocolStr string + actionStr string + portsJoined string +} + +func appendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, rc firewallRuleContext) []*FirewallRule { + if !peer.IPv6.IsValid() || !targetPeer.SupportsIPv6() || !targetPeer.IPv6.IsValid() { + return rules + } + + v6IP := peer.IPv6.String() + v6RuleID := rule.ID + v6IP + rc.dirStr + rc.protocolStr + rc.actionStr + rc.portsJoined + if _, ok := rulesExists[v6RuleID]; ok { + return rules + } + rulesExists[v6RuleID] = struct{}{} + + v6fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: v6IP, + Direction: rc.direction, + Action: rc.actionStr, + Protocol: rc.protocolStr, + } + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + return append(rules, &v6fr) + } + return append(rules, expandPortsAndRanges(v6fr, rule, targetPeer)...) 
+} diff --git a/management/server/types/networkmap_components_correctness_test.go b/management/server/types/networkmap_components_correctness_test.go index 5cd41ff10..bcfb6fdf9 100644 --- a/management/server/types/networkmap_components_correctness_test.go +++ b/management/server/types/networkmap_components_correctness_test.go @@ -42,7 +42,7 @@ func buildScalableTestAccount(numPeers, numGroups int, withDefaultPolicy bool) ( for i := range numPeers { peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)} + ip := netip.AddrFrom4([4]byte{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)}) wtVersion := "0.25.0" if i%2 == 0 { wtVersion = "0.40.0" @@ -1083,7 +1083,7 @@ func TestComponents_PeerIsNameserverExcludedFromNSGroup(t *testing.T) { nsIP := account.Peers["peer-0"].IP account.NameServerGroups["ns-self"] = &nbdns.NameServerGroup{ ID: "ns-self", Name: "Self NS", Enabled: true, Groups: []string{"group-all"}, - NameServers: []nbdns.NameServer{{IP: netip.AddrFrom4([4]byte{nsIP[0], nsIP[1], nsIP[2], nsIP[3]}), NSType: nbdns.UDPNameServerType, Port: 53}}, + NameServers: []nbdns.NameServer{{IP: nsIP, NSType: nbdns.UDPNameServerType, Port: 53}}, } nm := componentsNetworkMap(account, "peer-0", validatedPeers) diff --git a/management/server/types/networkmap_components_test.go b/management/server/types/networkmap_components_test.go new file mode 100644 index 000000000..1a99b4511 --- /dev/null +++ b/management/server/types/networkmap_components_test.go @@ -0,0 +1,787 @@ +package types_test + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes 
"github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap { + t.Helper() + return account.GetPeerNetworkMapFromComponents( + context.Background(), + peerID, + account.GetPeersCustomZone(context.Background(), "netbird.io"), + nil, + validatedPeers, + account.GetResourcePoliciesMap(), + account.GetResourceRoutersMap(), + nil, + account.GetActiveGroupUsers(), + ) +} + +func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} { + excludeSet := make(map[string]struct{}, len(excludePeerIDs)) + for _, id := range excludePeerIDs { + excludeSet[id] = struct{}{} + } + validated := make(map[string]struct{}, len(account.Peers)) + for id := range account.Peers { + if _, excluded := excludeSet[id]; !excluded { + validated[id] = struct{}{} + } + } + return validated +} + +func peerIDs(peers []*nbpeer.Peer) []string { + ids := make([]string, len(peers)) + for i, p := range peers { + ids[i] = p.ID + } + return ids +} + +func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.NotNil(t, nm) + assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy") + assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy") + assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself") + assert.Empty(t, nm.OfflinePeers, "no expired peers expected") +} + +func 
TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) { + account := createComponentTestAccount() + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-intra-src", Name: "Intra-source connectivity", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-intra-src", Name: "src <-> src", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-src"}, + }}, + }) + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Contains(t, peerIDs(nm.Peers), "peer-src-2", "should see peer from same group with intra-group policy") +} + +func TestNetworkMapComponents_FirewallRules(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + require.NotEmpty(t, nm.FirewallRules, "firewall rules should be generated") + + var hasAcceptAll bool + for _, rule := range nm.FirewallRules { + if rule.Protocol == string(types.PolicyRuleProtocolALL) && rule.Action == string(types.PolicyTrafficActionAccept) { + hasAcceptAll = true + } + } + assert.True(t, hasAcceptAll, "should have an accept-all firewall rule from the base policy") +} + +func TestNetworkMapComponents_LoginExpiration(t *testing.T) { + account := createComponentTestAccount() + account.Settings.PeerLoginExpirationEnabled = true + account.Settings.PeerLoginExpiration = 1 * time.Hour + + expiredTime := time.Now().Add(-2 * time.Hour) + account.Peers["peer-dst-1"].LoginExpirationEnabled = true + account.Peers["peer-dst-1"].LastLogin = &expiredTime + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.Contains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "expired peer should be in OfflinePeers") + 
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "expired peer should NOT be in active Peers") +} + +func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account, "peer-dst-1") + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "non-validated peer should be excluded") + assert.NotContains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "non-validated peer should not be in offline peers either") +} + +func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account, "peer-src-1") + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.Empty(t, nm.Peers, "non-validated target peer should get empty network map") + assert.Empty(t, nm.FirewallRules) +} + +func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasResourceRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" { + hasResourceRoute = true + break + } + } + assert.True(t, hasResourceRoute, "source peer should receive route to network resource via router") + assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "source peer should see the routing peer") +} + +func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + var hasResourceRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" { + hasResourceRoute = true + break + } + } + assert.True(t, hasResourceRoute, "router peer should receive network 
resource route") + assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource") +} + +func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "unrelated peer should not receive network resource route") + } +} + +func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.PostureChecks = []*posture.Checks{ + {ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}, + }}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id, + SourcePostureChecks: []string{"pc-version"}, + Rules: []*types.PolicyRule{{ + ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-guarded"}, + }}, + }) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", 
Enabled: true, AccountID: account.Id, + }) + + t.Run("peer passes posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasGuardedRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.1.1/32" { + hasGuardedRoute = true + } + } + assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route") + }) + + t.Run("peer fails posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route") + } + }) +} + +func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.PostureChecks = []*posture.Checks{ + {ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}, + }}, + {ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}}, + }}, + } + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id, + SourcePostureChecks: []string{"pc-version", "pc-os"}, + Rules: []*types.PolicyRule{{ + ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-strict"}, + }}, + }) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-strict", 
NetworkID: "net-strict", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-strict", Name: "Strict Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id, + }) + + t.Run("passes both posture checks", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + account.Peers["peer-src-1"].Meta.GoOS = "linux" + account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var found bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.2.1/32" { + found = true + } + } + assert.True(t, found, "peer passing both checks should get resource route") + }) + + t.Run("fails version posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0" + account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route") + } + }) + + t.Run("fails OS posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route") + } + }) +} + +func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := 
networkMapFromComponents(t, account, "peer-router-1", validated) + + var resourceFWRules []*types.RouteFirewallRule + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "10.200.0.1/32" { + resourceFWRules = append(resourceFWRules, rule) + } + } + assert.NotEmpty(t, resourceFWRules, "router should have firewall rules for the network resource") + + var hasSourcePeerIP bool + for _, rule := range resourceFWRules { + for _, sr := range rule.SourceRanges { + if sr == account.Peers["peer-src-1"].IP.String()+"/32" || sr == account.Peers["peer-src-2"].IP.String()+"/32" { + hasSourcePeerIP = true + } + } + } + assert.True(t, hasSourcePeerIP, "resource firewall rules should include source peer IPs") +} + +func TestNetworkMapComponents_DNSManagement(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + t.Run("peer in DNS-enabled group", func(t *testing.T) { + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.True(t, nm.DNSConfig.ServiceEnable, "peer in non-disabled group should have DNS enabled") + }) + + t.Run("peer in DNS-disabled group", func(t *testing.T) { + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + assert.False(t, nm.DNSConfig.ServiceEnable, "peer in DNS-disabled group should have DNS disabled") + }) +} + +func TestNetworkMapComponents_NameServerGroups(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.True(t, nm.DNSConfig.ServiceEnable) + + var hasNSGroup bool + for _, ns := range nm.DNSConfig.NameServerGroups { + if ns.ID == "ns-main" { + hasNSGroup = true + } + } + assert.True(t, hasNSGroup, "peer in NS group should receive nameserver configuration") +} + +func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + 
account.Routes["route-ha-1"] = &route.Route{ + ID: "route-ha-1", Network: netip.MustParsePrefix("172.16.0.0/16"), + Peer: account.Peers["peer-dst-1"].Key, PeerID: "peer-dst-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"}, + } + account.Routes["route-ha-2"] = &route.Route{ + ID: "route-ha-2", Network: netip.MustParsePrefix("172.16.0.0/16"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 200, AccountID: account.Id, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-src"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + haCount := 0 + for _, r := range nm.Routes { + if r.Network.String() == "172.16.0.0/16" { + haCount++ + } + } + assert.Equal(t, 1, haCount, "peer should only receive one route from HA group (not both, since it's a member of one)") +} + +func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-acl"] = &route.Route{ + ID: "route-acl", Network: netip.MustParsePrefix("192.168.100.0/24"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-dst"}, + PeerGroups: []string{"group-src"}, + AccessControlGroups: []string{"group-dst"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasFWRule bool + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "192.168.100.0/24" { + hasFWRule = true + } + } + assert.True(t, hasFWRule, "routing peer should have firewall rules for route with access control groups") +} + +func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-open"] = &route.Route{ + 
ID: "route-open", Network: netip.MustParsePrefix("10.99.0.0/16"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-src"}, + PeerGroups: []string{"group-src"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasFWRule bool + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "10.99.0.0/16" { + hasFWRule = true + } + } + assert.True(t, hasFWRule, "route without access control groups should have default permit firewall rules") +} + +func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Peers["peer-dst-1"].SSHEnabled = true + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "SSH to dst", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH") +} + +func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, p := range account.Policies { + p.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers") + assert.Empty(t, nm.FirewallRules) +} + +func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, r := range account.Routes { + r.Enabled = false + 
} + for _, r := range account.NetworkResources { + r.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Empty(t, nm.Routes, "disabled routes should not appear in network map") +} + +func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, r := range account.NetworkResources { + r.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "disabled resource should not generate routes") + } +} + +func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nmSrc := networkMapFromComponents(t, account, "peer-src-1", validated) + nmDst := networkMapFromComponents(t, account, "peer-dst-1", validated) + + assert.Contains(t, peerIDs(nmSrc.Peers), "peer-dst-1", "src should see dst via bidirectional policy") + assert.Contains(t, peerIDs(nmDst.Peers), "peer-src-1", "dst should see src via bidirectional policy") +} + +func TestNetworkMapComponents_DropPolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-drop", Name: "Drop traffic", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-drop", Name: "Drop src->dst", Enabled: true, + Action: types.PolicyTrafficActionDrop, Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"5432"}, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasDropRule bool + for _, rule := range nm.FirewallRules { + if rule.Action == string(types.PolicyTrafficActionDrop) && rule.Port == "5432" { + 
hasDropRule = true + } + } + assert.True(t, hasDropRule, "drop policy should generate drop firewall rule") +} + +func TestNetworkMapComponents_PortRangePolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Peers["peer-src-1"].Meta.WtVersion = "0.50.0" + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-range", Name: "Port range", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-range", Name: "Range rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasRangeRule bool + for _, rule := range nm.FirewallRules { + if rule.PortRange.Start == 8080 && rule.PortRange.End == 8090 { + hasRangeRule = true + } + } + assert.True(t, hasRangeRule, "port range policy should generate corresponding firewall rule") +} + +func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-2", NetworkID: "net-1", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.2/32"), Address: "10.200.0.2/32", + }) + account.Groups["group-res2"] = &types.Group{ID: "group-res2", Name: "Resource 2 Group", Peers: []string{"peer-src-1", "peer-src-2"}, + Resources: []types.Resource{{ID: "resource-2"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-res2", Name: "Resource 2 Policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-res2", Name: "Access Resource 2", Enabled: true, + Action: 
types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-2"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + resourceRouteCount := 0 + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" || r.Network.String() == "10.200.0.2/32" { + resourceRouteCount++ + } + } + assert.Equal(t, 2, resourceRouteCount, "router should have routes for both network resources") +} + +func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-domain", NetworkID: "net-1", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Domain, Domain: "api.example.com", Address: "api.example.com", + }) + account.Groups["group-res-domain"] = &types.Group{ + ID: "group-res-domain", Name: "Domain Resource Group", + Resources: []types.Resource{{ID: "resource-domain"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-domain", Name: "Domain resource policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-domain", Name: "Access domain resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-domain"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasDomainRoute bool + for _, r := range nm.Routes { + if r.NetworkType == route.DomainNetwork && len(r.Domains) > 0 && r.Domains[0].SafeString() == "api.example.com" { + hasDomainRoute = true + } + } + assert.True(t, hasDomainRoute, "source peer should receive domain route for domain network resource") +} + +func TestNetworkMapComponents_NetworkEmpty(t 
*testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "nonexistent-peer", validated) + + assert.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) + assert.NotNil(t, nm.Network) +} + +func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-other", NetworkID: "net-other", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.99.1/32"), Address: "10.200.99.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-other", Name: "Other Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-other", NetworkID: "net-other", Peer: "peer-dst-1", Enabled: true, AccountID: account.Id, + }) + account.Groups["group-res-other"] = &types.Group{ID: "group-res-other", Name: "Other resource group", + Resources: []types.Resource{{ID: "resource-other"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-other-resource", Name: "Other resource policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-other", Name: "Other resource access", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-other"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.99.1/32", r.Network.String(), "router-1 should NOT get routes for other network's resources") + } +} + +func createComponentTestAccount() *types.Account { + peers := 
map[string]*nbpeer.Peer{ + "peer-src-1": { + ID: "peer-src-1", IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), Key: "key-src-1", DNSLabel: "src1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-src-2": { + ID: "peer-src-2", IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}), Key: "key-src-2", DNSLabel: "src2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-dst-1": { + ID: "peer-dst-1", IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}), Key: "key-dst-1", DNSLabel: "dst1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-router-1": { + ID: "peer-router-1", IP: netip.AddrFrom4([4]byte{100, 64, 0, 10}), Key: "key-router-1", DNSLabel: "router1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + } + + groups := map[string]*types.Group{ + "group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}}, + "group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}}, + "group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}}, + "group-res": { + ID: "group-res", Name: "Resource Group", + Resources: []types.Resource{{ID: "resource-1"}}, + }, + } + + policies := []*types.Policy{ + { + ID: "policy-base", Name: "Base connectivity", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-base", Name: "Allow src <-> dst", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }, + 
{ + ID: "policy-resource", Name: "Network resource access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-resource", Name: "Source -> Resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-1"}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + "route-main": { + ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1", + Enabled: true, Metric: 100, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"}, + }, + } + + users := map[string]*types.User{ + "user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}}, + "user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}}, + } + + account := &types.Account{ + Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Users: users, + Network: &types.Network{ + Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "ns-main": { + ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{}, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32", + }, + }, + Networks: []*networkTypes.Network{ + {ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"}, + }, + 
NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go deleted file mode 100644 index 53261f22d..000000000 --- a/management/server/types/networkmap_golden_test.go +++ /dev/null @@ -1,967 +0,0 @@ -package types_test - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/netip" - "os" - "path/filepath" - "slices" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/route" -) - -const ( - numPeers = 100 - devGroupID = "group-dev" - opsGroupID = "group-ops" - allGroupID = "group-all" - sshUsersGroupID = "group-ssh-users" - routeID = route.ID("route-main") - routeHA1ID = route.ID("route-ha-1") - routeHA2ID = route.ID("route-ha-2") - policyIDDevOps = "policy-dev-ops" - policyIDAll = "policy-all" - policyIDPosture = "policy-posture" - policyIDDrop = "policy-drop" - policyIDSSH = "policy-ssh" - postureCheckID = "posture-check-ver" - networkResourceID = "res-database" - networkID = "net-database" - 
networkRouterID = "router-database" - nameserverGroupID = "ns-group-main" - testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. - expiredPeerID = "peer-98" // This peer will be online but with an expired session. - offlinePeerID = "peer-99" // This peer will be completely offline. - routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. - testAccountID = "account-golden-test" - userAdminID = "user-admin" - userDevID = "user-dev" - userOpsID = "user-ops" -) - -func TestGetPeerNetworkMap_Golden(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - 
require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - - b.ResetTimer() - b.Run("old builder", func(b *testing.B) { - for range b.N { - for _, peerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - b.ResetTimer() - b.Run("new builder", func(b *testing.B) { - for range b.N { - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - for _, peerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newPeerID := "peer-new-101" - newPeerIP := net.IP{100, 64, 1, 1} - newPeer := &nbpeer.Peer{ - ID: newPeerID, - IP: newPeerIP, - Key: fmt.Sprintf("key-%s", newPeerID), - DNSLabel: "peernew101", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: 
time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newPeerID] = newPeer - - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = append(devGroup.Peers, newPeerID) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newPeerID) - } - - validatedPeersMap[newPeerID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerAddedIncremental(account, newPeerID) - require.NoError(t, err, "error adding peer to cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - 
require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newPeerID := "peer-new-101" - newPeer := &nbpeer.Peer{ - ID: newPeerID, - IP: net.IP{100, 64, 1, 1}, - Key: fmt.Sprintf("key-%s", newPeerID), - DNSLabel: "peernew101", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - } - - account.Peers[newPeerID] = newPeer - account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) - account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) - validatedPeersMap[newPeerID] = struct{}{} - - b.ResetTimer() - b.Run("old builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(account, newPeerID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := 
context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - 
require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerAddedIncremental(account, newRouterID) - require.NoError(t, err, "error adding router to cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: 
nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - b.ResetTimer() - b.Run("old builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(account, newRouterID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedPeerID := "peer-25" - - delete(account.Peers, deletedPeerID) 
- - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - delete(validatedPeersMap, deletedPeerID) - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerDeleted(account, deletedPeerID) - require.NoError(t, err, "error deleting peer from cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, 
string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match") - } -} - -func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedRouterID := "peer-75" - - var affectedRoute *route.Route - for _, r := range account.Routes { - if r.PeerID == deletedRouterID { - affectedRoute = r - break - } - } - require.NotNil(t, affectedRoute, "Router peer should have a route") - - for _, group := range account.Groups { - group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { - return id == deletedRouterID - }) - } - - for routeID, r := range account.Routes { - if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { - delete(account.Routes, routeID) - } - } - delete(account.Peers, deletedRouterID) - delete(validatedPeersMap, deletedRouterID) - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerDeleted(account, deletedRouterID) - require.NoError(t, err, "error deleting routing peer from cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) 
- normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - - deletedPeerID := "peer-25" - - delete(account.Peers, deletedPeerID) - account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { - return id == deletedPeerID - }) - account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { - return id == deletedPeerID - }) - delete(validatedPeersMap, deletedPeerID) - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - b.ResetTimer() - b.Run("old builder after delete", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, 
[]*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after delete", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerDeleted(account, deletedPeerID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { - for _, peer := range networkMap.Peers { - if peer.Status != nil { - peer.Status.LastSeen = time.Time{} - } - peer.LastLogin = &time.Time{} - } - for _, peer := range networkMap.OfflinePeers { - if peer.Status != nil { - peer.Status.LastSeen = time.Time{} - } - peer.LastLogin = &time.Time{} - } - - sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) - sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) - sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) - - sort.Slice(networkMap.FirewallRules, func(i, j int) bool { - r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] - if r1.PeerIP != r2.PeerIP { - return r1.PeerIP < r2.PeerIP - } - if r1.Protocol != r2.Protocol { - return r1.Protocol < r2.Protocol - } - if r1.Direction != r2.Direction { - return r1.Direction < r2.Direction - } - if r1.Action != r2.Action { - return r1.Action < r2.Action - } - return r1.Port < r2.Port - }) - - sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { - r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] - if r1.RouteID != r2.RouteID { - return r1.RouteID < r2.RouteID - } - if r1.Action != r2.Action { - return r1.Action < r2.Action - } - if r1.Destination != r2.Destination { - return r1.Destination < r2.Destination - } - if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) 
> 0 { - if r1.SourceRanges[0] != r2.SourceRanges[0] { - return r1.SourceRanges[0] < r2.SourceRanges[0] - } - } - return r1.Port < r2.Port - }) - - for _, ranges := range networkMap.RoutesFirewallRules { - sort.Slice(ranges.SourceRanges, func(i, j int) bool { - return ranges.SourceRanges[i] < ranges.SourceRanges[j] - }) - } -} - -type networkMapJSON struct { - Peers []*nbpeer.Peer `json:"Peers"` - Network *types.Network `json:"Network"` - Routes []*route.Route `json:"Routes"` - DNSConfig dns.Config `json:"DNSConfig"` - OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"` - FirewallRules []*types.FirewallRule `json:"FirewallRules"` - RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"` - ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"` - AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"` - EnableSSH bool `json:"EnableSSH"` -} - -func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON { - result := &networkMapJSON{ - Peers: nm.Peers, - Network: nm.Network, - Routes: nm.Routes, - DNSConfig: nm.DNSConfig, - OfflinePeers: nm.OfflinePeers, - FirewallRules: nm.FirewallRules, - RoutesFirewallRules: nm.RoutesFirewallRules, - ForwardingRules: nm.ForwardingRules, - EnableSSH: nm.EnableSSH, - } - - if len(nm.AuthorizedUsers) > 0 { - result.AuthorizedUsers = make(map[string][]string) - localUsers := make([]string, 0, len(nm.AuthorizedUsers)) - for localUser := range nm.AuthorizedUsers { - localUsers = append(localUsers, localUser) - } - sort.Strings(localUsers) - - for _, localUser := range localUsers { - userIDs := nm.AuthorizedUsers[localUser] - sortedUserIDs := make([]string, 0, len(userIDs)) - for userID := range userIDs { - sortedUserIDs = append(sortedUserIDs, userID) - } - sort.Strings(sortedUserIDs) - result.AuthorizedUsers[localUser] = sortedUserIDs - } - } - - return result -} - -func createTestAccountWithEntities() *types.Account { - peers := make(map[string]*nbpeer.Peer) - devGroupPeers, opsGroupPeers, 
allGroupPeers := []string{}, []string{}, []string{} - - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, 64, 0, byte(i + 1)} - wtVersion := "0.25.0" - if i%2 == 0 { - wtVersion = "0.40.0" - } - - p := &nbpeer.Peer{ - ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, - } - - if peerID == expiredPeerID { - p.LoginExpirationEnabled = true - pastTimestamp := time.Now().Add(-2 * time.Hour) - p.LastLogin = &pastTimestamp - } - - peers[peerID] = p - allGroupPeers = append(allGroupPeers, peerID) - if i < numPeers/2 { - devGroupPeers = append(devGroupPeers, peerID) - } else { - opsGroupPeers = append(opsGroupPeers, peerID) - } - - } - - groups := map[string]*types.Group{ - allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, - devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, - opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, - sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}}, - } - - policies := []*types.Policy{ - { - ID: policyIDAll, Name: "Default-Allow", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{allGroupID}, Destinations: []string{allGroupID}, - }}, - }, - { - ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, - PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: 
policyIDDrop, Name: "Drop DB traffic", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, - Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, - SourcePostureChecks: []string{postureCheckID}, - Rules: []*types.PolicyRule{{ - ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, - }}, - }, - { - ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}}, - }}, - }, - } - - routes := map[route.ID]*route.Route{ - routeID: { - ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), - Peer: peers["peer-75"].Key, - PeerID: "peer-75", - Description: "Route to internal resource", Enabled: true, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - }, - routeHA1ID: { - ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-80"].Key, - PeerID: "peer-80", - Description: "HA Route 1", Enabled: true, Metric: 1000, - PeerGroups: []string{allGroupID}, - Groups: []string{allGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - routeHA2ID: { - ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-90"].Key, - 
PeerID: "peer-90", - Description: "HA Route 2", Enabled: true, Metric: 900, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - } - - users := map[string]*types.User{ - userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}}, - userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}}, - userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}}, - } - - account := &types.Account{ - Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, - Users: users, - Network: &types.Network{ - Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, - }, - DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, - NameServerGroups: map[string]*dns.NameServerGroup{ - nameserverGroupID: { - ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, - NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, - }, - }, - PostureChecks: []*posture.Checks{ - {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, - }}, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, - }, - Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, - NetworkRouters: []*routerTypes.NetworkRouter{ - {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, - }, - 
Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, - } - - for _, p := range account.Policies { - p.AccountID = account.Id - } - for _, r := range account.Routes { - r.AccountID = account.Id - } - - return account -} - -func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - 
builder.EnqueuePeersForIncrementalAdd(account, newRouterID) - - time.Sleep(100 * time.Millisecond) - - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") - - t.Log("Update golden file with OnPeerAdded router...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") -} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go deleted file mode 100644 index 6448b8403..000000000 --- a/management/server/types/networkmapbuilder.go +++ /dev/null @@ -1,2317 +0,0 @@ -package types - -import ( - "context" - "fmt" - "slices" - "strconv" - "strings" - "sync" - "time" - - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - - "github.com/netbirdio/netbird/client/ssh/auth" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/route" -) - -const ( - allPeers = "0.0.0.0" - allWildcard = "0.0.0.0/0" - v6AllWildcard = "::/0" - fw = "fw:" - rfw = "route-fw:" - - 
szAddPeerBatch = 10 - maxPeerAddRetries = 20 -) - -type NetworkMapCache struct { - globalRoutes map[route.ID]*route.Route - globalRules map[string]*FirewallRule //ruleId - globalRouteRules map[string]*RouteFirewallRule //ruleId - globalPeers map[string]*nbpeer.Peer - - groupToPeers map[string][]string - peerToGroups map[string][]string - policyToRules map[string][]*PolicyRule //policyId - groupToPolicies map[string][]*Policy - groupToRoutes map[string][]*route.Route - peerToRoutes map[string][]*route.Route - - peerACLs map[string]*PeerACLView - peerRoutes map[string]*PeerRoutesView - peerDNS map[string]*nbdns.Config - peerSSH map[string]*PeerSSHView - - groupIDToUserIDs map[string][]string - allowedUserIDs map[string]struct{} - - resourceRouters map[string]map[string]*routerTypes.NetworkRouter - resourcePolicies map[string][]*Policy - - globalResources map[string]*resourceTypes.NetworkResource // resourceId - - acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info - noACGRoutes map[route.ID]*RouteOwnerInfo - - mu sync.RWMutex -} - -type RouteOwnerInfo struct { - PeerID string - RouteID route.ID -} - -type PeerACLView struct { - ConnectedPeerIDs []string - FirewallRuleIDs []string -} - -type PeerRoutesView struct { - OwnRouteIDs []route.ID - NetworkResourceIDs []route.ID - InheritedRouteIDs []route.ID - RouteFirewallRuleIDs []string -} - -type PeerSSHView struct { - EnableSSH bool - AuthorizedUsers map[string]map[string]struct{} -} - -type NetworkMapBuilder struct { - account *Account - cache *NetworkMapCache - validatedPeers map[string]struct{} - - apb addPeerBatch -} - -type addPeerBatch struct { - mu sync.Mutex - sg *sync.Cond - ids []string - la *Account - retryCount map[string]int -} - -func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { - builder := &NetworkMapBuilder{ - cache: &NetworkMapCache{ - globalRoutes: make(map[route.ID]*route.Route), - globalRules: make(map[string]*FirewallRule), 
- globalRouteRules: make(map[string]*RouteFirewallRule), - globalPeers: make(map[string]*nbpeer.Peer), - groupToPeers: make(map[string][]string), - peerToGroups: make(map[string][]string), - policyToRules: make(map[string][]*PolicyRule), - groupToPolicies: make(map[string][]*Policy), - groupToRoutes: make(map[string][]*route.Route), - peerToRoutes: make(map[string][]*route.Route), - peerACLs: make(map[string]*PeerACLView), - peerRoutes: make(map[string]*PeerRoutesView), - peerDNS: make(map[string]*nbdns.Config), - peerSSH: make(map[string]*PeerSSHView), - groupIDToUserIDs: make(map[string][]string), - allowedUserIDs: make(map[string]struct{}), - globalResources: make(map[string]*resourceTypes.NetworkResource), - acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), - noACGRoutes: make(map[route.ID]*RouteOwnerInfo), - }, - validatedPeers: make(map[string]struct{}), - } - builder.apb.sg = sync.NewCond(&builder.apb.mu) - builder.apb.ids = make([]string, 0, szAddPeerBatch) - builder.apb.la = account - builder.apb.retryCount = make(map[string]int) - - maps.Copy(builder.validatedPeers, validatedPeers) - - builder.initialBuild(account) - - go builder.incAddPeerLoop() - return builder -} - -func (b *NetworkMapBuilder) initialBuild(account *Account) { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - b.account = account - - start := time.Now() - - b.buildGlobalIndexes(account) - - resourceRouters := account.GetResourceRoutersMap() - resourcePolicies := account.GetResourcePoliciesMap() - b.cache.resourceRouters = resourceRouters - b.cache.resourcePolicies = resourcePolicies - - for peerID := range account.Peers { - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - } - - log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) -} - -func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { - clear(b.cache.globalPeers) - 
clear(b.cache.groupToPeers) - clear(b.cache.peerToGroups) - clear(b.cache.policyToRules) - clear(b.cache.groupToPolicies) - clear(b.cache.globalRoutes) - clear(b.cache.globalRules) - clear(b.cache.globalRouteRules) - clear(b.cache.globalResources) - clear(b.cache.groupToRoutes) - clear(b.cache.peerToRoutes) - clear(b.cache.acgToRoutes) - clear(b.cache.noACGRoutes) - clear(b.cache.groupIDToUserIDs) - clear(b.cache.allowedUserIDs) - clear(b.cache.peerSSH) - - maps.Copy(b.cache.globalPeers, account.Peers) - - b.cache.groupIDToUserIDs = account.GetActiveGroupUsers() - b.cache.allowedUserIDs = b.buildAllowedUserIDs(account) - - for groupID, group := range account.Groups { - peersCopy := make([]string, len(group.Peers)) - copy(peersCopy, group.Peers) - b.cache.groupToPeers[groupID] = peersCopy - - for _, peerID := range group.Peers { - b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) - } - } - - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - b.cache.policyToRules[policy.ID] = policy.Rules - - affectedGroups := make(map[string]struct{}) - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, groupID := range rule.Sources { - affectedGroups[groupID] = struct{}{} - } - for _, groupID := range rule.Destinations { - affectedGroups[groupID] = struct{}{} - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - groupId := rule.SourceResource.ID - affectedGroups[groupId] = struct{}{} - b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - groupId := rule.DestinationResource.ID - affectedGroups[groupId] = struct{}{} - b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) - } - } - - for groupID := range affectedGroups { - 
b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) - } - } - - for _, resource := range account.NetworkResources { - if !resource.Enabled { - continue - } - b.cache.globalResources[resource.ID] = resource - } - - for _, r := range account.Routes { - if !r.Enabled { - continue - } - for _, groupID := range r.PeerGroups { - b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) - } - if r.Peer != "" { - if peer, ok := b.cache.globalPeers[r.Peer]; ok { - b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) - } - } - } -} - -func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { - peer := account.GetPeer(peerID) - if peer == nil { - return - } - - allPotentialPeers, firewallRules, authorizedUsers, sshEnabled := b.getPeerConnectionResources(account, peer, b.validatedPeers) - - isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) - - var emptyExpiredPeers []*nbpeer.Peer - finalAllPeers := b.addNetworksRoutingPeers( - networkResourcesRoutes, - peer, - allPotentialPeers, - emptyExpiredPeers, - isRouter, - sourcePeers, - ) - - view := &PeerACLView{ - ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), - FirewallRuleIDs: make([]string, 0, len(firewallRules)), - } - - for _, p := range finalAllPeers { - view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) - } - - for _, rule := range firewallRules { - ruleID := b.generateFirewallRuleID(rule) - view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) - b.cache.globalRules[ruleID] = rule - } - - b.cache.peerACLs[peerID] = view - b.cache.peerSSH[peerID] = &PeerSSHView{ - EnableSSH: sshEnabled, - AuthorizedUsers: authorizedUsers, - } -} - -func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, - validatedPeersMap map[string]struct{}, -) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { - peerID := peer.ID - ctx := 
context.Background() - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - rulesExists := make(map[string]struct{}) - peersExists := make(map[string]struct{}) - fwRules := make([]*FirewallRule, 0) - peers := make([]*nbpeer.Peer, 0) - - authorizedUsers := make(map[string]map[string]struct{}) - sshEnabled := false - - for _, group := range peerGroups { - policies := b.cache.groupToPolicies[group] - for _, policy := range policies { - if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { - continue - } - rules := b.cache.policyToRules[policy.ID] - for _, rule := range rules { - var sourcePeers, destinationPeers []*nbpeer.Peer - var peerInSources, peerInDestinations bool - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - peerInSources = rule.SourceResource.ID == peerID - } else { - peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - peerInDestinations = rule.DestinationResource.ID == peerID - } else { - peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) - } - - if !peerInSources && !peerInDestinations { - continue - } - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - peer := account.GetPeer(rule.SourceResource.ID) - if peer != nil { - sourcePeers = []*nbpeer.Peer{peer} - } - } else { - sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - peer := account.GetPeer(rule.DestinationResource.ID) - if peer != nil { - destinationPeers = []*nbpeer.Peer{peer} - } - } else { - destinationPeers = 
b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) - } - - if rule.Bidirectional { - if peerInSources { - b.generateResourcescached( - rule, destinationPeers, FirewallRuleDirectionIN, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - if peerInDestinations { - b.generateResourcescached( - rule, sourcePeers, FirewallRuleDirectionOUT, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - } - - if peerInSources { - b.generateResourcescached( - rule, destinationPeers, FirewallRuleDirectionOUT, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - - if peerInDestinations { - b.generateResourcescached( - rule, sourcePeers, FirewallRuleDirectionIN, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - - if rule.Protocol == PolicyRuleProtocolNetbirdSSH { - sshEnabled = true - switch { - case len(rule.AuthorizedGroups) > 0: - for groupID, localUsers := range rule.AuthorizedGroups { - userIDs, ok := b.cache.groupIDToUserIDs[groupID] - if !ok { - continue - } - - if len(localUsers) == 0 { - localUsers = []string{auth.Wildcard} - } - - for _, localUser := range localUsers { - if authorizedUsers[localUser] == nil { - authorizedUsers[localUser] = make(map[string]struct{}) - } - for _, userID := range userIDs { - authorizedUsers[localUser][userID] = struct{}{} - } - } - } - case rule.AuthorizedUser != "": - if authorizedUsers[auth.Wildcard] == nil { - authorizedUsers[auth.Wildcard] = make(map[string]struct{}) - } - authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} - default: - authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) - } - } else if policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { - sshEnabled = true - authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) - } - } - } - } - } - - return peers, fwRules, authorizedUsers, sshEnabled -} - -func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { - for _, 
groupID := range groupIDs { - if _, exists := peerGroupsMap[groupID]; exists { - return true - } - } - return false -} - -func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, - excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, -) []*nbpeer.Peer { - ctx := context.Background() - uniquePeers := make(map[string]*nbpeer.Peer) - - for _, groupID := range groupIDs { - peerIDs := b.cache.groupToPeers[groupID] - for _, peerID := range peerIDs { - if peerID == excludePeerID { - continue - } - - if _, ok := validatedPeersMap[peerID]; !ok { - continue - } - - peer := b.cache.globalPeers[peerID] - if peer == nil { - continue - } - - if len(postureChecksIDs) > 0 { - if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { - continue - } - } - - uniquePeers[peerID] = peer - } - } - - result := make([]*nbpeer.Peer, 0, len(uniquePeers)) - for _, peer := range uniquePeers { - result = append(result, peer) - } - - return result -} - -func (b *NetworkMapBuilder) generateResourcescached( - rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, - peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, -) { - for _, peer := range groupPeers { - if peer == nil { - continue - } - if _, ok := peersExists[peer.ID]; !ok { - *peers = append(*peers, peer) - peersExists[peer.ID] = struct{}{} - } - - fr := FirewallRule{ - PolicyID: rule.ID, - PeerIP: peer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - - var s strings.Builder - s.WriteString(rule.ID) - s.WriteString(fr.PeerIP) - s.WriteString(strconv.Itoa(direction)) - s.WriteString(fr.Protocol) - s.WriteString(fr.Action) - s.WriteString(strings.Join(rule.Ports, ",")) - - ruleID := s.String() - - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - if 
len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { - *rules = append(*rules, &fr) - continue - } - - *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...) - } -} - -func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) { - ctx := context.Background() - peerID := peer.ID - - var isRoutingPeer bool - var routes []*route.Route - allSourcePeers := make(map[string]struct{}) - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - for _, resource := range b.cache.globalResources { - - networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID] - resourcePolicies := b.cache.resourcePolicies[resource.ID] - if len(resourcePolicies) == 0 { - continue - } - - isRouterForThisResource := false - - if networkRoutingPeers != nil { - if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled { - isRoutingPeer = true - isRouterForThisResource = true - if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil { - routes = append(routes, rt) - } - } - } - - hasAccessAsClient := false - if !isRouterForThisResource { - for _, policy := range resourcePolicies { - if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) { - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - hasAccessAsClient = true - break - } - } - } - } - - if hasAccessAsClient && networkRoutingPeers != nil { - for routerPeerID, router := range networkRoutingPeers { - if router.Enabled { - if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil { - routes = append(routes, rt) - } - } - } - } - - if isRouterForThisResource { - for _, policy := range resourcePolicies { - var peersWithAccess []*nbpeer.Peer - if policy.Rules[0].SourceResource.Type == 
ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peersWithAccess = []*nbpeer.Peer{peer} - } else { - peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers) - } - for _, p := range peersWithAccess { - allSourcePeers[p.ID] = struct{}{} - } - } - } - } - - return isRoutingPeer, routes, allSourcePeers -} - -func (b *NetworkMapBuilder) createNetworkResourceRoutes( - resource *resourceTypes.NetworkResource, routerPeerID string, - router *routerTypes.NetworkRouter, resourcePolicies []*Policy, -) *route.Route { - if len(resourcePolicies) > 0 { - peer := b.cache.globalPeers[routerPeerID] - if peer != nil { - return resource.ToRoute(peer, router) - } - } - return nil -} - -func (b *NetworkMapBuilder) addNetworksRoutingPeers( - networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, - expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{}, -) []*nbpeer.Peer { - - networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) - for _, r := range networkResourcesRoutes { - networkRoutesPeers[r.PeerID] = struct{}{} - } - - delete(sourcePeers, peer.ID) - delete(networkRoutesPeers, peer.ID) - - for _, existingPeer := range peersToConnect { - delete(sourcePeers, existingPeer.ID) - delete(networkRoutesPeers, existingPeer.ID) - } - for _, expPeer := range expiredPeers { - delete(sourcePeers, expPeer.ID) - delete(networkRoutesPeers, expPeer.ID) - } - - missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) - if isRouter { - for p := range sourcePeers { - missingPeers[p] = struct{}{} - } - } - for p := range networkRoutesPeers { - missingPeers[p] = struct{}{} - } - - for p := range missingPeers { - if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { - peersToConnect = append(peersToConnect, missingPeer) - } - } - - return peersToConnect -} - -func (b *NetworkMapBuilder) 
buildPeerRoutesView(account *Account, peerID string) { - ctx := context.Background() - peer := account.GetPeer(peerID) - if peer == nil { - return - } - resourcePolicies := b.cache.resourcePolicies - - view := &PeerRoutesView{ - OwnRouteIDs: make([]route.ID, 0), - NetworkResourceIDs: make([]route.ID, 0), - RouteFirewallRuleIDs: make([]string, 0), - } - - enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) - for _, rt := range enabledRoutes { - if rt.PeerID != "" && rt.PeerID != peerID { - if b.cache.globalPeers[rt.PeerID] == nil { - continue - } - } - - view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) - b.cache.globalRoutes[rt.ID] = rt - } - - aclView := b.cache.peerACLs[peerID] - if aclView != nil { - peerRoutesMembership := make(LookupMap) - for _, r := range append(enabledRoutes, disabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(LookupMap) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - for _, aclPeerID := range aclView.ConnectedPeerIDs { - if aclPeerID == peerID { - continue - } - activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) - groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) - haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - - for _, inheritedRoute := range haFilteredRoutes { - view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) - b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute - } - } - } - - _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) - - for _, rt := range networkResourcesRoutes { - view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) - b.cache.globalRoutes[rt.ID] = rt - } - - allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) - b.updateACGIndexForPeer(peerID, allRoutes) - - routeFirewallRules := 
b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) - for _, rule := range routeFirewallRules { - ruleID := b.generateRouteFirewallRuleID(rule) - view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) - b.cache.globalRouteRules[ruleID] = rule - } - - if len(networkResourcesRoutes) > 0 { - networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) - for _, rule := range networkResourceFirewallRules { - ruleID := b.generateRouteFirewallRuleID(rule) - view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) - b.cache.globalRouteRules[ruleID] = rule - } - } - - b.cache.peerRoutes[peerID] = view -} - -func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { - for acg, routeMap := range b.cache.acgToRoutes { - for routeID, info := range routeMap { - if info.PeerID == peerID { - delete(routeMap, routeID) - } - } - if len(routeMap) == 0 { - delete(b.cache.acgToRoutes, acg) - } - } - - for routeID, info := range b.cache.noACGRoutes { - if info.PeerID == peerID { - delete(b.cache.noACGRoutes, routeID) - } - } - - for _, rt := range routes { - if !rt.Enabled { - continue - } - - if len(rt.AccessControlGroups) == 0 { - b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ - PeerID: peerID, - RouteID: rt.ID, - } - } else { - for _, acg := range rt.AccessControlGroups { - if b.cache.acgToRoutes[acg] == nil { - b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) - } - - b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ - PeerID: peerID, - RouteID: rt.ID, - } - } - } - } -} - -func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - peer := b.cache.globalPeers[peerID] - if peer == nil { - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := 
seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - // maybe here is some mess - here we store peer key (see comment below) - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - peerGroups := b.cache.peerToGroups[peerID] - for _, groupID := range peerGroups { - groupRoutes := b.cache.groupToRoutes[groupID] - for _, r := range groupRoutes { - newPeerRoute := r.Copy() - // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes - newPeerRoute.Peer = peerID - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) - takeRoute(newPeerRoute, peerID) - } - } - for _, r := range b.cache.peerToRoutes[peerID] { - takeRoute(r.Copy(), peerID) - } - return enabledRoutes, disabledRoutes -} - -func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0) - - enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) - for _, route := range enabledRoutes { - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := b.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) - - rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) - routesFirewallRules = append(routesFirewallRules, rules...) 
- } - } - - return routesFirewallRules -} - -func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - groupPeers := b.cache.groupToPeers[id] - if groupPeers == nil { - continue - } - - for _, pID := range groupPeers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { - routePolicies := make(map[string]*Policy) - - for _, groupID := range accessControlGroups { - candidatePolicies := b.cache.groupToPolicies[groupID] - - for _, policy := range candidatePolicies { - if _, found := routePolicies[policy.ID]; found { - continue - } - policyRules := b.cache.policyToRules[policy.ID] - for _, rule := range policyRules { - if slices.Contains(rule.Destinations, groupID) { - routePolicies[policy.ID] = policy - break - } - } - } - } - - return maps.Values(routePolicies) -} - -func (b *NetworkMapBuilder) getRouteFirewallRules( - peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, - distributionPeers map[string]struct{}, account *Account, -) []*RouteFirewallRule { - ctx := context.Background() - var fwRules []*RouteFirewallRule - for _, policy := range policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) - - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) - fwRules = append(fwRules, rules...) 
- } - } - return fwRules -} - -func (b *NetworkMapBuilder) getRulePeers( - rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, - validatedPeersMap map[string]struct{}, account *Account, -) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) - - for _, id := range rule.Sources { - groupPeers := b.cache.groupToPeers[id] - if groupPeers == nil { - continue - } - - for _, pID := range groupPeers { - if pID == peerID { - continue - } - _, distPeer := distributionPeers[pID] - _, valid := validatedPeersMap[pID] - - if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { - distPeersWithPolicy[pID] = struct{}{} - } - } - } - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - _, distPeer := distributionPeers[rule.SourceResource.ID] - _, valid := validatedPeersMap[rule.SourceResource.ID] - if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { - distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} - } - } - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { - peer := b.cache.globalPeers[pID] - if peer == nil { - continue - } - distributionGroupPeers = append(distributionGroupPeers, peer) - } - return distributionGroupPeers -} - -func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { - peerGroups := b.cache.peerToGroups[peerID] - checkGroups := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - checkGroups[groupID] = struct{}{} - } - - dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) - dnsConfig := &nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) - } - - b.cache.peerDNS[peerID] = dnsConfig -} - -func (b 
*NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { - - enabled := true - for _, groupID := range account.DNSSettings.DisabledManagementGroups { - _, found := checkGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - -func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := checkGroups[gID] - if found { - peer := b.cache.globalPeers[peerID] - if !peerIsNameserver(peer, nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -func (b *NetworkMapBuilder) buildAllowedUserIDs(account *Account) map[string]struct{} { - users := make(map[string]struct{}) - for _, nbUser := range account.Users { - if !nbUser.IsBlocked() && !nbUser.IsServiceUser { - users[nbUser.Id] = struct{}{} - } - } - return users -} - -func firewallRuleProtocol(protocol PolicyRuleProtocolType) string { - if protocol == PolicyRuleProtocolNetbirdSSH { - return string(PolicyRuleProtocolTCP) - } - return string(protocol) -} - -// lock should be held -func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account { - if account.Network.CurrentSerial() > b.account.Network.CurrentSerial() { - b.account = account - } - return b.account -} - -func (b *NetworkMapBuilder) GetPeerNetworkMap( - ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, - validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - - b.cache.mu.RLock() - defer b.cache.mu.RUnlock() - - account := b.account - - peer := account.GetPeer(peerID) - if peer == nil { - return &NetworkMap{Network: account.Network.Copy()} - } 
- - aclView := b.cache.peerACLs[peerID] - routesView := b.cache.peerRoutes[peerID] - dnsConfig := b.cache.peerDNS[peerID] - sshView := b.cache.peerSSH[peerID] - - if aclView == nil || routesView == nil || dnsConfig == nil { - return &NetworkMap{Network: account.Network.Copy()} - } - - nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers) - - if metrics != nil { - objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", - account.Id, objectCount) - } - } - - return nm -} - -func (b *NetworkMapBuilder) assembleNetworkMap( - ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, - dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{}, -) *NetworkMap { - - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - - for _, peerID := range aclView.ConnectedPeerIDs { - if _, ok := validatedPeers[peerID]; !ok { - continue - } - - peer := b.cache.globalPeers[peerID] - if peer == nil { - continue - } - - expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) - if account.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, peer) - } else { - peersToConnect = append(peersToConnect, peer) - } - } - - var routes []*route.Route - allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) - - for _, routeID := range allRouteIDs { - if route := b.cache.globalRoutes[routeID]; route != nil { - routes = append(routes, route) - } - } - - var 
firewallRules []*FirewallRule - for _, ruleID := range aclView.FirewallRuleIDs { - if rule := b.cache.globalRules[ruleID]; rule != nil { - firewallRules = append(firewallRules, rule) - } else { - log.Debugf("NetworkMapBuilder: peer %s assembling network map has no fwrule %s in globalRules", peer.ID, ruleID) - } - } - - var routesFirewallRules []*RouteFirewallRule - for _, ruleID := range routesView.RouteFirewallRuleIDs { - if rule := b.cache.globalRouteRules[ruleID]; rule != nil { - routesFirewallRules = append(routesFirewallRules, rule) - } - } - - finalDNSConfig := *dnsConfig - if finalDNSConfig.ServiceEnable { - var zones []nbdns.CustomZone - - peerGroupsSlice := b.cache.peerToGroups[peer.ID] - peerGroups := make(LookupMap, len(peerGroupsSlice)) - for _, groupID := range peerGroupsSlice { - peerGroups[groupID] = struct{}{} - } - - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers) - zones = append(zones, nbdns.CustomZone{ - Domain: peersCustomZone.Domain, - Records: records, - }) - } - - filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) - zones = append(zones, filteredAccountZones...) 
- - finalDNSConfig.CustomZones = zones - } - - nm := &NetworkMap{ - Peers: peersToConnect, - Network: account.Network.Copy(), - Routes: routes, - DNSConfig: finalDNSConfig, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: routesFirewallRules, - } - - if sshView != nil { - nm.EnableSSH = sshView.EnableSSH - nm.AuthorizedUsers = sshView.AuthorizedUsers - } - - return nm -} - -func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { - var s strings.Builder - s.WriteString(fw) - s.WriteString(rule.PolicyID) - s.WriteRune(':') - s.WriteString(rule.PeerIP) - s.WriteRune(':') - s.WriteString(strconv.Itoa(rule.Direction)) - s.WriteRune(':') - s.WriteString(rule.Protocol) - s.WriteRune(':') - s.WriteString(rule.Action) - s.WriteRune(':') - s.WriteString(rule.Port) - s.WriteRune(':') - s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) - s.WriteRune('-') - s.WriteString(strconv.Itoa(int(rule.PortRange.End))) - return s.String() -} - -func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { - var s strings.Builder - s.WriteString(rfw) - s.WriteString(string(rule.RouteID)) - s.WriteRune(':') - s.WriteString(rule.Destination) - s.WriteRune(':') - s.WriteString(rule.Action) - s.WriteRune(':') - s.WriteString(strings.Join(rule.SourceRanges, ",")) - s.WriteRune(':') - s.WriteString(rule.Protocol) - s.WriteRune(':') - s.WriteString(strconv.Itoa(int(rule.Port))) - return s.String() -} - -func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(peerGroups, groupID) { - return true - } - } - return false -} - -func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { - for _, r := range account.Routes { - if !r.Enabled { - continue - } - - if r.PeerID == peerID { - return true - } - - if peer := b.cache.globalPeers[peerID]; peer != nil { - if r.Peer == peer.Key && r.PeerID == "" 
{ - return true - } - } - } - - routers := account.GetResourceRoutersMap() - for _, networkRouters := range routers { - if router, exists := networkRouters[peerID]; exists && router.Enabled { - return true - } - } - - return false -} - -func (b *NetworkMapBuilder) incAddPeerLoop() { - for { - b.apb.mu.Lock() - if len(b.apb.ids) == 0 { - b.apb.sg.Wait() - } - b.addPeersIncrementally() - b.apb.mu.Unlock() - } -} - -// lock on b.apb level should be held -func (b *NetworkMapBuilder) addPeersIncrementally() { - peers := slices.Clone(b.apb.ids) - clear(b.apb.ids) - b.apb.ids = b.apb.ids[:0] - latestAcc := b.apb.la - b.apb.mu.Unlock() - - tt := time.Now() - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(latestAcc) - - log.Debugf("NetworkMapBuilder: Starting incremental add of %d peers", len(peers)) - - allUpdates := make(map[string]*PeerUpdateDelta) - - for _, peerID := range peers { - peer := account.GetPeer(peerID) - if peer == nil { - b.apb.mu.Lock() - retries := b.apb.retryCount[peerID] - b.apb.mu.Unlock() - - if retries >= maxPeerAddRetries { - log.Errorf("NetworkMapBuilder: peer %s not found in account %s after %d retries, giving up", peerID, account.Id, retries) - b.apb.mu.Lock() - delete(b.apb.retryCount, peerID) - b.apb.mu.Unlock() - continue - } - - log.Warnf("NetworkMapBuilder: peer %s not found in account %s, retry %d/%d", peerID, account.Id, retries+1, maxPeerAddRetries) - b.apb.mu.Lock() - b.apb.retryCount[peerID] = retries + 1 - b.apb.mu.Unlock() - b.enqueuePeersForIncrementalAdd(latestAcc, peerID) - continue - } - - b.apb.mu.Lock() - delete(b.apb.retryCount, peerID) - b.apb.mu.Unlock() - - b.validatedPeers[peerID] = struct{}{} - b.cache.globalPeers[peerID] = peer - - peerGroups := b.updateIndexesForNewPeer(account, peerID) - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - - peerDeltas := b.collectDeltasForNewPeer(account, peerID, peerGroups) - for 
affectedPeerID, delta := range peerDeltas { - if existing, ok := allUpdates[affectedPeerID]; ok { - existing.mergeFrom(delta) - continue - } - allUpdates[affectedPeerID] = delta - } - } - - for affectedPeerID, delta := range allUpdates { - b.applyDeltaToPeer(account, affectedPeerID, delta) - } - - log.Debugf("NetworkMapBuilder: Added %d peers to cache, affected %d peers, took %s", len(peers), len(allUpdates), time.Since(tt)) - - b.apb.mu.Lock() - if len(b.apb.ids) > 0 { - b.apb.sg.Signal() - } -} - -func (b *NetworkMapBuilder) enqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { - b.apb.mu.Lock() - b.apb.ids = append(b.apb.ids, peerIDs...) - if b.apb.la != nil && acc.Network.CurrentSerial() > b.apb.la.Network.CurrentSerial() { - b.apb.la = acc - } - b.apb.sg.Signal() - b.apb.mu.Unlock() -} - -func (b *NetworkMapBuilder) EnqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { - b.enqueuePeersForIncrementalAdd(acc, peerIDs...) -} - -type ViewDelta struct { - AddedPeerIDs []string - RemovedPeerIDs []string - AddedRuleIDs []string - RemovedRuleIDs []string -} - -func (b *NetworkMapBuilder) OnPeerAddedIncremental(acc *Account, peerID string) error { - tt := time.Now() - peer := acc.GetPeer(peerID) - if peer == nil { - return fmt.Errorf("NetworkMapBuilder: peer %s not found in account", peerID) - } - - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(acc) - - log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) - - b.validatedPeers[peerID] = struct{}{} - - b.cache.globalPeers[peerID] = peer - - peerGroups := b.updateIndexesForNewPeer(account, peerID) - - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - - log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) - - b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) - - log.Debugf("NetworkMapBuilder: Added peer %s 
to cache, took %s", peerID, time.Since(tt)) - - return nil -} - -func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { - peerGroups := make([]string, 0) - - for groupID, group := range account.Groups { - if slices.Contains(group.Peers, peerID) { - if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { - b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) - } - peerGroups = append(peerGroups, groupID) - } - } - - b.cache.peerToGroups[peerID] = peerGroups - - for _, r := range account.Routes { - if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { - continue - } - for _, groupID := range r.PeerGroups { - if !slices.Contains(b.cache.groupToRoutes[groupID], r) { - b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) - } - } - if r.Peer != "" { - if peer, ok := b.cache.globalPeers[r.Peer]; ok { - if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { - b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) - } - } - } - b.cache.globalRoutes[r.ID] = r - } - - return peerGroups -} - -func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { - updates := b.collectDeltasForNewPeer(account, newPeerID, peerGroups) - for affectedPeerID, delta := range updates { - b.applyDeltaToPeer(account, affectedPeerID, delta) - } -} - -func (b *NetworkMapBuilder) collectDeltasForNewPeer(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { - updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) - - if b.isPeerRouter(account, newPeerID) { - affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) - for affectedPeerID := range affectedByRoutes { - if affectedPeerID == newPeerID { - continue - } - if _, exists := updates[affectedPeerID]; !exists { - updates[affectedPeerID] = &PeerUpdateDelta{ - PeerID: affectedPeerID, - RebuildRoutesView: true, 
- } - } else { - updates[affectedPeerID].RebuildRoutesView = true - } - } - } - - return updates -} - -func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { - affected := make(map[string]struct{}) - enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) - - for _, route := range enabledRoutes { - for _, distGroupID := range route.Groups { - if peers := b.cache.groupToPeers[distGroupID]; peers != nil { - for _, peerID := range peers { - if peerID != newRouterID { - affected[peerID] = struct{}{} - } - } - } - } - - for _, peerGroupID := range route.PeerGroups { - if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { - for _, peerID := range peers { - if peerID != newRouterID { - affected[peerID] = struct{}{} - } - } - } - } - } - - for _, route := range account.Routes { - if !route.Enabled { - continue - } - - routerInPeerGroups := false - for _, peerGroupID := range route.PeerGroups { - if slices.Contains(routerGroups, peerGroupID) { - routerInPeerGroups = true - break - } - } - - if routerInPeerGroups { - for _, distGroupID := range route.Groups { - if peers := b.cache.groupToPeers[distGroupID]; peers != nil { - for _, peerID := range peers { - affected[peerID] = struct{}{} - } - } - } - } - } - - return affected -} - -func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { - updates := make(map[string]*PeerUpdateDelta) - ctx := context.Background() - - groupAllLn := 0 - if allGroup, err := account.GetGroupAll(); err == nil { - groupAllLn = len(allGroup.Peers) - 1 - } - - newPeer := b.cache.globalPeers[newPeerID] - if newPeer == nil { - return updates - } - - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - var peerInSources, peerInDestinations bool - - if 
rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == newPeerID { - peerInSources = true - } else { - peerInSources = b.isPeerInGroups(rule.Sources, peerGroups) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == newPeerID { - peerInDestinations = true - } else { - peerInDestinations = b.isPeerInGroups(rule.Destinations, peerGroups) - } - - if peerInSources { - if len(rule.Destinations) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionIN) - } - } - - if peerInDestinations { - if len(rule.Sources) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) - } - } - - if rule.Bidirectional { - if peerInSources { - if len(rule.Destinations) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) - } - } - if peerInDestinations { - if len(rule.Sources) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionIN) - } - } - } - } - } - - 
b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) - - b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) - - b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) - - return updates -} - -func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( - ctx context.Context, account *Account, newPeerID string, - updates map[string]*PeerUpdateDelta, -) { - resourceRouters := b.cache.resourceRouters - - for networkID, routers := range resourceRouters { - router, isRouter := routers[newPeerID] - if !isRouter || !router.Enabled { - continue - } - - for _, resource := range b.cache.globalResources { - if resource.NetworkID != networkID { - continue - } - - policies := b.cache.resourcePolicies[resource.ID] - if len(policies) == 0 { - continue - } - - peersWithAccess := make(map[string]struct{}) - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - sourceGroups := policy.SourceGroups() - for _, sourceGroup := range sourceGroups { - groupPeers := b.cache.groupToPeers[sourceGroup] - for _, peerID := range groupPeers { - if peerID == newPeerID { - continue - } - - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - peersWithAccess[peerID] = struct{}{} - } - } - } - } - - for peerID := range peersWithAccess { - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - } - updates[peerID] = delta - } - - if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - delta.RebuildRoutesView = true - } - } - } -} - -func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( - newPeerID string, newPeer *nbpeer.Peer, - peerGroups []string, updates map[string]*PeerUpdateDelta, -) { - processedPeerRoutes := make(map[string]map[route.ID]struct{}) - - for routeID, info := range b.cache.noACGRoutes { - if info.PeerID == 
newPeerID { - continue - } - - b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) - - if processedPeerRoutes[info.PeerID] == nil { - processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) - } - processedPeerRoutes[info.PeerID][routeID] = struct{}{} - } - - for _, acg := range peerGroups { - routeInfos := b.cache.acgToRoutes[acg] - if routeInfos == nil { - continue - } - - for routeID, info := range routeInfos { - if info.PeerID == newPeerID { - continue - } - - if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { - if _, processed := processedRoutes[routeID]; processed { - continue - } - } - - b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) - - if processedPeerRoutes[info.PeerID] == nil { - processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) - } - processedPeerRoutes[info.PeerID][routeID] = struct{}{} - } - } -} - -func (b *NetworkMapBuilder) addRouteFirewallUpdate( - updates map[string]*PeerUpdateDelta, peerID string, - routeID string, sourceIP string, -) { - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), - } - updates[peerID] = delta - } - - for _, existing := range delta.UpdateRouteFirewallRules { - if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { - return - } - } - - delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ - RuleID: routeID, - AddSourceIP: sourceIP, - }) -} - -func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( - ctx context.Context, account *Account, newPeerID string, - newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, -) { - for _, resource := range b.cache.globalResources { - resourcePolicies := b.cache.resourcePolicies - resourceRouters := b.cache.resourceRouters - - policies := resourcePolicies[resource.ID] - 
peerHasAccess := false - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - sourceGroups := policy.SourceGroups() - for _, sourceGroup := range sourceGroups { - if slices.Contains(peerGroups, sourceGroup) { - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { - peerHasAccess = true - break - } - } - } - - if peerHasAccess { - break - } - } - - if !peerHasAccess { - continue - } - - networkRouters := resourceRouters[resource.NetworkID] - for routerPeerID, router := range networkRouters { - if !router.Enabled || routerPeerID == newPeerID { - continue - } - - delta := updates[routerPeerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: routerPeerID, - } - updates[routerPeerID] = delta - } - - if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - delta.RebuildRoutesView = true - } - } -} - -type PeerUpdateDelta struct { - PeerID string - AddConnectedPeers []string - AddFirewallRules []*FirewallRuleDelta - AddRoutes []route.ID - UpdateRouteFirewallRules []*RouteFirewallRuleUpdate - UpdateDNS bool - RebuildRoutesView bool -} - -func (d *PeerUpdateDelta) mergeFrom(other *PeerUpdateDelta) { - for _, peerID := range other.AddConnectedPeers { - if !slices.Contains(d.AddConnectedPeers, peerID) { - d.AddConnectedPeers = append(d.AddConnectedPeers, peerID) - } - } - - existingRuleIDs := make(map[string]struct{}, len(d.AddFirewallRules)) - for _, rule := range d.AddFirewallRules { - existingRuleIDs[rule.RuleID] = struct{}{} - } - for _, rule := range other.AddFirewallRules { - if _, exists := existingRuleIDs[rule.RuleID]; !exists { - d.AddFirewallRules = append(d.AddFirewallRules, rule) - existingRuleIDs[rule.RuleID] = struct{}{} - } - } - - for _, routeID := range other.AddRoutes { - if !slices.Contains(d.AddRoutes, routeID) { - d.AddRoutes = append(d.AddRoutes, routeID) - } - } - - existingRouteUpdates := 
make(map[string]map[string]struct{}) - for _, update := range d.UpdateRouteFirewallRules { - if existingRouteUpdates[update.RuleID] == nil { - existingRouteUpdates[update.RuleID] = make(map[string]struct{}) - } - existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} - } - for _, update := range other.UpdateRouteFirewallRules { - if existingRouteUpdates[update.RuleID] == nil { - existingRouteUpdates[update.RuleID] = make(map[string]struct{}) - } - if _, exists := existingRouteUpdates[update.RuleID][update.AddSourceIP]; !exists { - d.UpdateRouteFirewallRules = append(d.UpdateRouteFirewallRules, update) - existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} - } - } - - if other.UpdateDNS { - d.UpdateDNS = true - } - if other.RebuildRoutesView { - d.RebuildRoutesView = true - } -} - -type FirewallRuleDelta struct { - Rule *FirewallRule - RuleID string - Direction int -} - -type RouteFirewallRuleUpdate struct { - RuleID string - AddSourceIP string -} - -func (b *NetworkMapBuilder) addUpdateForPeersInGroups( - updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, - rule *PolicyRule, direction int, allGroupLn int, -) { - for _, groupID := range groupIDs { - peers := b.cache.groupToPeers[groupID] - cnt := 0 - for _, peerID := range peers { - if peerID == newPeerID { - continue - } - if _, ok := b.validatedPeers[peerID]; !ok { - continue - } - cnt++ - } - all := false - if allGroupLn > 0 && cnt == allGroupLn { - all = true - } - newPeer := b.cache.globalPeers[newPeerID] - fr := &FirewallRule{ - PolicyID: rule.ID, - PeerIP: newPeer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - for _, peerID := range peers { - if peerID == newPeerID { - continue - } - if _, ok := b.validatedPeers[peerID]; !ok { - continue - } - targetPeer := b.cache.globalPeers[peerID] - if targetPeer == nil { - continue - } - - peerIPForRule := fr.PeerIP - if all { - peerIPForRule = 
allPeers - } - - b.addOrUpdateFirewallRuleInDelta(updates, peerID, newPeerID, rule, direction, fr, peerIPForRule, targetPeer) - } - } -} - -func (b *NetworkMapBuilder) addUpdateForDirectPeerResource( - updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, - rule *PolicyRule, direction int, -) { - if targetPeerID == newPeerID { - return - } - - if _, ok := b.validatedPeers[targetPeerID]; !ok { - return - } - - newPeer := b.cache.globalPeers[newPeerID] - if newPeer == nil { - return - } - - targetPeer := b.cache.globalPeers[targetPeerID] - if targetPeer == nil { - return - } - - fr := &FirewallRule{ - PolicyID: rule.ID, - PeerIP: newPeer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - - b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer) -} - -func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( - updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, - rule *PolicyRule, direction int, baseRule *FirewallRule, peerIP string, targetPeer *nbpeer.Peer, -) { - delta := updates[targetPeerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: targetPeerID, - AddConnectedPeers: []string{newPeerID}, - AddFirewallRules: make([]*FirewallRuleDelta, 0), - } - updates[targetPeerID] = delta - } else if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - baseRule.PeerIP = peerIP - - if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { - expandedRules := expandPortsAndRanges(*baseRule, rule, targetPeer) - for _, expandedRule := range expandedRules { - ruleID := b.generateFirewallRuleID(expandedRule) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: expandedRule, - RuleID: ruleID, - Direction: direction, - }) - } - } else { - ruleID := b.generateFirewallRuleID(baseRule) - 
delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: baseRule, - RuleID: ruleID, - Direction: direction, - }) - } -} - -func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { - if len(delta.AddConnectedPeers) > 0 || len(delta.AddFirewallRules) > 0 { - if aclView := b.cache.peerACLs[peerID]; aclView != nil { - for _, connectedPeerID := range delta.AddConnectedPeers { - if !slices.Contains(aclView.ConnectedPeerIDs, connectedPeerID) { - aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, connectedPeerID) - } - } - - for _, ruleDelta := range delta.AddFirewallRules { - b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule - - if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { - aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) - } - } - } - } - - if delta.RebuildRoutesView { - b.buildPeerRoutesView(account, peerID) - } else if len(delta.UpdateRouteFirewallRules) > 0 { - if routesView := b.cache.peerRoutes[peerID]; routesView != nil { - b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) - } - } - - if delta.UpdateDNS { - b.buildPeerDNSView(account, peerID) - } -} - -func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { - for _, update := range updates { - for _, ruleID := range routesView.RouteFirewallRuleIDs { - rule := b.cache.globalRouteRules[ruleID] - if rule == nil { - continue - } - - if string(rule.RouteID) == update.RuleID { - if hasWildcard := slices.Contains(rule.SourceRanges, allWildcard) || slices.Contains(rule.SourceRanges, v6AllWildcard); hasWildcard { - break - } - - sourceIP := update.AddSourceIP - - if strings.Contains(sourceIP, ":") { - sourceIP += "/128" // IPv6 - } else { - sourceIP += "/32" // IPv4 - } - - if !slices.Contains(rule.SourceRanges, sourceIP) { - rule.SourceRanges = append(rule.SourceRanges, sourceIP) - } - break - } - } - } -} 
- -func (b *NetworkMapBuilder) OnPeerDeleted(acc *Account, peerID string) error { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(acc) - - deletedPeer := b.cache.globalPeers[peerID] - if deletedPeer == nil { - return fmt.Errorf("peer %s not found in cache", peerID) - } - - deletedPeerKey := deletedPeer.Key - peerGroups := b.cache.peerToGroups[peerID] - peerIP := deletedPeer.IP.String() - - log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) - - delete(b.validatedPeers, peerID) - - routesToDelete := []route.ID{} - - for routeID, r := range account.Routes { - if r.Peer != deletedPeerKey && r.PeerID != peerID { - continue - } - if len(r.PeerGroups) == 0 { - routesToDelete = append(routesToDelete, routeID) - continue - } - newPeerAssigned := false - for _, groupID := range r.PeerGroups { - candidatePeerIDs := b.cache.groupToPeers[groupID] - for _, candidatePeerID := range candidatePeerIDs { - if candidatePeerID == peerID { - continue - } - if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { - r.Peer = candidatePeer.Key - r.PeerID = candidatePeerID - newPeerAssigned = true - break - } - } - if newPeerAssigned { - break - } - } - - if !newPeerAssigned { - routesToDelete = append(routesToDelete, routeID) - } - } - - for _, routeID := range routesToDelete { - delete(account.Routes, routeID) - } - - delete(b.cache.peerACLs, peerID) - delete(b.cache.peerRoutes, peerID) - delete(b.cache.peerDNS, peerID) - delete(b.cache.peerSSH, peerID) - - delete(b.cache.globalPeers, peerID) - - for acg, routeMap := range b.cache.acgToRoutes { - for routeID, info := range routeMap { - if info.PeerID == peerID { - delete(routeMap, routeID) - } - } - if len(routeMap) == 0 { - delete(b.cache.acgToRoutes, acg) - } - } - - for _, groupID := range peerGroups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { 
- return id == peerID - }) - } - } - delete(b.cache.peerToGroups, peerID) - - affectedPeers := make(map[string]struct{}) - - for _, r := range account.Routes { - for _, groupID := range r.Groups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - for _, p := range peers { - affectedPeers[p] = struct{}{} - } - } - } - - for _, groupID := range r.PeerGroups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - for _, p := range peers { - affectedPeers[p] = struct{}{} - } - } - } - } - - for affectedPeerID := range affectedPeers { - if affectedPeerID == peerID { - continue - } - b.buildPeerRoutesView(account, affectedPeerID) - } - - peersToRebuildACL := make(map[string]struct{}) - peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP, peerGroups, peersToRebuildACL) - for affectedPeerID, updates := range peerDeletionUpdates { - b.applyDeletionUpdates(affectedPeerID, updates) - } - - for affectedPeerID := range peersToRebuildACL { - b.buildPeerACLView(account, affectedPeerID) - } - - b.cleanupUnusedRules() - - log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) - - return nil -} - -func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( - deletedPeerID string, - peerIP string, - peerGroups []string, - peersToRebuildACL map[string]struct{}, -) map[string]*PeerDeletionUpdate { - - affected := make(map[string]*PeerDeletionUpdate) - - for peerID, aclView := range b.cache.peerACLs { - if peerID == deletedPeerID { - continue - } - - if slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { - peersToRebuildACL[peerID] = struct{}{} - if affected[peerID] == nil { - affected[peerID] = &PeerDeletionUpdate{ - RemovePeerID: deletedPeerID, - PeerIP: peerIP, - } - } - } - } - - affectedRouteOwners := make(map[string]struct{}) - - for _, groupID := range peerGroups { - if routeMap, ok := b.cache.acgToRoutes[groupID]; ok { - for _, info := range routeMap { - if info.PeerID != 
deletedPeerID { - affectedRouteOwners[info.PeerID] = struct{}{} - } - } - } - } - - for _, info := range b.cache.noACGRoutes { - if info.PeerID != deletedPeerID { - affectedRouteOwners[info.PeerID] = struct{}{} - } - } - - for ownerPeerID := range affectedRouteOwners { - if affected[ownerPeerID] == nil { - affected[ownerPeerID] = &PeerDeletionUpdate{ - RemovePeerID: deletedPeerID, - PeerIP: peerIP, - RemoveFromSourceRanges: true, - } - } else { - affected[ownerPeerID].RemoveFromSourceRanges = true - } - } - - return affected -} - -type PeerDeletionUpdate struct { - RemovePeerID string - RemoveFirewallRuleIDs []string - RemoveRouteIDs []route.ID - RemoveFromSourceRanges bool - PeerIP string -} - -func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { - if routesView := b.cache.peerRoutes[peerID]; routesView != nil { - if len(updates.RemoveRouteIDs) > 0 { - routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { - return slices.Contains(updates.RemoveRouteIDs, routeID) - }) - } - - if updates.RemoveFromSourceRanges { - b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) - } - } -} - -func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { - sourceIPv4 := peerIP + "/32" - sourceIPv6 := peerIP + "/128" - - rulesToRemove := []string{} - - for _, ruleID := range routesView.RouteFirewallRuleIDs { - if rule := b.cache.globalRouteRules[ruleID]; rule != nil { - rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { - return source == sourceIPv4 || source == sourceIPv6 || source == peerIP - }) - - if len(rule.SourceRanges) == 0 { - rulesToRemove = append(rulesToRemove, ruleID) - } - } - } - - if len(rulesToRemove) > 0 { - routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { - return slices.Contains(rulesToRemove, ruleID) - }) - } -} - -func 
(b *NetworkMapBuilder) cleanupUnusedRules() { - usedFirewallRules := make(map[string]struct{}) - usedRouteRules := make(map[string]struct{}) - usedRoutes := make(map[route.ID]struct{}) - - for _, aclView := range b.cache.peerACLs { - for _, ruleID := range aclView.FirewallRuleIDs { - usedFirewallRules[ruleID] = struct{}{} - } - } - - for _, routesView := range b.cache.peerRoutes { - for _, ruleID := range routesView.RouteFirewallRuleIDs { - usedRouteRules[ruleID] = struct{}{} - } - - for _, routeID := range routesView.OwnRouteIDs { - usedRoutes[routeID] = struct{}{} - } - for _, routeID := range routesView.NetworkResourceIDs { - usedRoutes[routeID] = struct{}{} - } - } - - for ruleID := range b.cache.globalRules { - if _, used := usedFirewallRules[ruleID]; !used { - delete(b.cache.globalRules, ruleID) - } - } - - for ruleID := range b.cache.globalRouteRules { - if _, used := usedRouteRules[ruleID]; !used { - delete(b.cache.globalRouteRules, ruleID) - } - } - - for routeID := range b.cache.globalRoutes { - if _, used := usedRoutes[routeID]; !used { - delete(b.cache.globalRoutes, routeID) - } - } -} - -func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - peerStored, ok := b.cache.globalPeers[peer.ID] - if !ok { - return - } - *peerStored = *peer -} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index d4e1a8816..d410aec8d 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy { return c } +func (p *Policy) Equal(other *Policy) bool { + if p == nil || other == nil { + return p == other + } + + if p.ID != other.ID || + p.AccountID != other.AccountID || + p.Name != other.Name || + p.Description != other.Description || + p.Enabled != other.Enabled { + return false + } + + if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) { + return false + } + + if len(p.Rules) != 
len(other.Rules) { + return false + } + + otherRules := make(map[string]*PolicyRule, len(other.Rules)) + for _, r := range other.Rules { + otherRules[r.ID] = r + } + for _, r := range p.Rules { + otherRule, ok := otherRules[r.ID] + if !ok { + return false + } + if !r.Equal(otherRule) { + return false + } + } + + return true +} + // EventMeta returns activity event meta related to this policy func (p *Policy) EventMeta() map[string]any { return map[string]any{"name": p.Name} diff --git a/management/server/types/policy_test.go b/management/server/types/policy_test.go new file mode 100644 index 000000000..b1d7aabc2 --- /dev/null +++ b/management/server/types/policy_test.go @@ -0,0 +1,193 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRules(t *testing.T) { + a := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + b := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRuleCount(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, 
a.Equal(b)) +} + +func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc3", "pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2", "pc3"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentPostureChecks(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentScalarFields(t *testing.T) { + base := Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Description: "desc", + Enabled: true, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Description = "changed" + assert.False(t, base.Equal(&other)) +} + +func TestPolicyEqual_NilCases(t *testing.T) { + var a *Policy + var b *Policy + assert.True(t, a.Equal(b)) + + a = &Policy{ID: "pol1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyEqual_RulesMismatchByID(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_FullScenario(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc2", "pc1"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g2", "g1"}, + Destinations: []string{"g4", "g3"}, + Ports: []string{"443", "80", "8080"}, + PortRanges: []RulePortRange{ + {Start: 8000, 
End: 9000}, + {Start: 80, End: 80}, + }, + }, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc1", "pc2"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g1", "g2"}, + Destinations: []string{"g3", "g4"}, + Ports: []string{"80", "8080", "443"}, + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + }, + }, + } + assert.True(t, a.Equal(b)) +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index bb75dd555..52c494a6a 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -1,6 +1,8 @@ package types import ( + "slices" + "github.com/netbirdio/netbird/shared/management/proto" ) @@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule { } return rule } + +func (pm *PolicyRule) Equal(other *PolicyRule) bool { + if pm == nil || other == nil { + return pm == other + } + + if pm.ID != other.ID || + pm.PolicyID != other.PolicyID || + pm.Name != other.Name || + pm.Description != other.Description || + pm.Enabled != other.Enabled || + pm.Action != other.Action || + pm.Bidirectional != other.Bidirectional || + pm.Protocol != other.Protocol || + pm.SourceResource != other.SourceResource || + pm.DestinationResource != other.DestinationResource || + pm.AuthorizedUser != other.AuthorizedUser { + return false + } + + if !stringSlicesEqualUnordered(pm.Sources, other.Sources) { + return false + } + if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) { + return false + } + if !stringSlicesEqualUnordered(pm.Ports, other.Ports) { + return false + } + if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) { + return false + } + if 
!authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) { + return false + } + + return true +} + +func stringSlicesEqualUnordered(a, b []string) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + sorted1 := make([]string, len(a)) + sorted2 := make([]string, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.Sort(sorted1) + slices.Sort(sorted2) + return slices.Equal(sorted1, sorted2) +} + +func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + cmp := func(x, y RulePortRange) int { + if x.Start != y.Start { + if x.Start < y.Start { + return -1 + } + return 1 + } + if x.End != y.End { + if x.End < y.End { + return -1 + } + return 1 + } + return 0 + } + sorted1 := make([]RulePortRange, len(a)) + sorted2 := make([]RulePortRange, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.SortFunc(sorted1, cmp) + slices.SortFunc(sorted2, cmp) + return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool { + return x.Start == y.Start && x.End == y.End + }) +} + +func authorizedGroupsEqual(a, b map[string][]string) bool { + if len(a) != len(b) { + return false + } + for k, va := range a { + vb, ok := b[k] + if !ok { + return false + } + if !stringSlicesEqualUnordered(va, vb) { + return false + } + } + return true +} diff --git a/management/server/types/policyrule_test.go b/management/server/types/policyrule_test.go new file mode 100644 index 000000000..816e72abb --- /dev/null +++ b/management/server/types/policyrule_test.go @@ -0,0 +1,194 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80", "22"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"22", "443", "80"}, + } + assert.True(t, a.Equal(b)) +} + +func 
TestPolicyRuleEqual_DifferentPorts(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "22"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2", "g3"}, + Destinations: []string{"g4", "g5"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g3", "g1", "g2"}, + Destinations: []string{"g5", "g4"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentSources(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 443}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1", "u2", "u3"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u3", "u1", "u2"}, + }, 
+ } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g2": {"u1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) { + base := PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Name: "test", + Description: "desc", + Enabled: true, + Action: PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Action = PolicyTrafficActionDrop + assert.False(t, base.Equal(&other)) + + other = base + other.Protocol = PolicyRuleProtocolUDP + assert.False(t, base.Equal(&other)) +} + +func TestPolicyRuleEqual_NilCases(t *testing.T) { + var a *PolicyRule + var b *PolicyRule + assert.True(t, a.Equal(b)) + + a = &PolicyRule{ID: "rule1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyRuleEqual_EmptySlices(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{}, + Sources: nil, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: nil, + Sources: []string{}, + } + assert.True(t, a.Equal(b)) +} + diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 4ea79ec72..264a018d4 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -46,6 +46,8 @@ type Settings struct { // NetworkRange is the custom network range for that account NetworkRange netip.Prefix `gorm:"serializer:json"` + // NetworkRangeV6 is the custom IPv6 network range for that account + NetworkRangeV6 netip.Prefix `gorm:"serializer:json"` // PeerExposeEnabled enables 
or disables peer-initiated service expose PeerExposeEnabled bool @@ -65,6 +67,12 @@ type Settings struct { // when false, updates require user interaction from the UI AutoUpdateAlways bool `gorm:"default:false"` + // IPv6EnabledGroups is the list of group IDs whose peers receive IPv6 overlay addresses. + // Peers not in any of these groups will not be allocated an IPv6 address. + // Empty list means IPv6 is disabled for the account. + // For new accounts this defaults to the All group. + IPv6EnabledGroups []string `gorm:"serializer:json"` + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. // This is a runtime-only field, not stored in the database. EmbeddedIdpEnabled bool `gorm:"-"` @@ -94,8 +102,10 @@ func (s *Settings) Copy() *Settings { LazyConnectionEnabled: s.LazyConnectionEnabled, DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, + NetworkRangeV6: s.NetworkRangeV6, AutoUpdateVersion: s.AutoUpdateVersion, AutoUpdateAlways: s.AutoUpdateAlways, + IPv6EnabledGroups: slices.Clone(s.IPv6EnabledGroups), EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, LocalAuthDisabled: s.LocalAuthDisabled, } diff --git a/management/server/types/update_reason.go b/management/server/types/update_reason.go new file mode 100644 index 000000000..9d752da9a --- /dev/null +++ b/management/server/types/update_reason.go @@ -0,0 +1,37 @@ +package types + +// UpdateReason describes why an account peers update was triggered. +type UpdateReason struct { + Resource UpdateResource + Operation UpdateOperation +} + +// UpdateResource represents the kind of resource that triggered an account peers update. 
+type UpdateResource string + +const ( + UpdateResourceAccountSettings UpdateResource = "account_settings" + UpdateResourceDNSSettings UpdateResource = "dns_settings" + UpdateResourceGroup UpdateResource = "group" + UpdateResourceNameServerGroup UpdateResource = "nameserver_group" + UpdateResourcePolicy UpdateResource = "policy" + UpdateResourcePostureCheck UpdateResource = "posture_check" + UpdateResourceRoute UpdateResource = "route" + UpdateResourceUser UpdateResource = "user" + UpdateResourcePeer UpdateResource = "peer" + UpdateResourceNetwork UpdateResource = "network" + UpdateResourceNetworkResource UpdateResource = "network_resource" + UpdateResourceNetworkRouter UpdateResource = "network_router" + UpdateResourceService UpdateResource = "service" + UpdateResourceZone UpdateResource = "zone" + UpdateResourceZoneRecord UpdateResource = "zone_record" +) + +// UpdateOperation represents the kind of change that triggered the update. +type UpdateOperation string + +const ( + UpdateOperationCreate UpdateOperation = "create" + UpdateOperationUpdate UpdateOperation = "update" + UpdateOperationDelete UpdateOperation = "delete" +) diff --git a/management/server/types/user_invite_test.go b/management/server/types/user_invite_test.go index 09dae3800..c77fb89e2 100644 --- a/management/server/types/user_invite_test.go +++ b/management/server/types/user_invite_test.go @@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) { _, plainToken, err := GenerateInviteToken() require.NoError(t, err) - // Modify one character in the secret part - modifiedToken := plainToken[:5] + "X" + plainToken[6:] + replacement := "X" + if plainToken[5] == 'X' { + replacement = "Y" + } + modifiedToken := plainToken[:5] + replacement + plainToken[6:] err = ValidateInviteToken(modifiedToken) require.Error(t, err) } diff --git a/management/server/user.go b/management/server/user.go index c1f984f2f..892d982e7 100644 --- a/management/server/user.go +++ b/management/server/user.go 
@@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" "unicode" @@ -15,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -395,8 +397,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } - if expiresIn < 1 || expiresIn > 365 { - return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") + if expiresIn < account.PATMinExpireDays || expiresIn > account.PATMaxExpireDays { + return nil, status.Errorf(status.InvalidArgument, "expiration has to be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays) } allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create) @@ -675,7 +677,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate}) } return updatedUsersInfo, globalErr @@ -824,6 +826,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } } } + + allGroupChanges := slices.Concat(removedGroups, addedGroups) + if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, allGroupChanges); err != nil { + return false, nil, nil, nil, fmt.Errorf("reconcile IPv6 for group changes: %w", err) + } } updateAccountPeers := len(userPeers) > 0 diff --git 
a/management/server/user_test.go b/management/server/user_test.go index 8fdfbd633..c77ea53d1 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b65..49afc7638 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -2,7 +2,13 @@ package cmd import ( "fmt" + "os" + "os/signal" + "path/filepath" "strconv" + "strings" + "syscall" + "time" "github.com/spf13/cobra" @@ -57,7 +63,11 @@ var debugSyncCmd = &cobra.Command{ SilenceUsage: true, } -var pingTimeout string +var ( + pingTimeout time.Duration + pingIPv4 bool + pingIPv6 bool +) var debugPingCmd = &cobra.Command{ Use: "ping [port]", @@ -99,6 +109,27 @@ var debugStopCmd = &cobra.Command{ SilenceUsage: true, } +var debugCaptureCmd = &cobra.Command{ + Use: "capture [filter expression]", + Short: "Capture packets on a client's WireGuard interface", + Long: `Captures decrypted packets flowing through a client's WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Filter arguments after the account ID use BPF-like syntax. 
+ +Examples: + netbird-proxy debug capture + netbird-proxy debug capture --duration 1m host 10.0.0.1 + netbird-proxy debug capture host 10.0.0.1 and tcp port 443 + netbird-proxy debug capture not port 22 + netbird-proxy debug capture -o capture.pcap + netbird-proxy debug capture --pcap | tcpdump -r - -n + netbird-proxy debug capture --pcap | tshark -r -`, + Args: cobra.MinimumNArgs(1), + RunE: runDebugCapture, + SilenceUsage: true, +} + func init() { debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address") debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format") @@ -108,7 +139,16 @@ func init() { debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)") debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)") - debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4") + debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6") + debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6") + + debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)") + debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + debugCaptureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length (text mode)") + debugCaptureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (text mode)") + debugCaptureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugClientsCmd) @@ -119,6 
+159,7 @@ func init() { debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStopCmd) + debugCmd.AddCommand(debugCaptureCmd) rootCmd.AddCommand(debugCmd) } @@ -157,7 +198,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error { } port = p } - return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout) + var ipVersion string + switch { + case pingIPv4: + ipVersion = "4" + case pingIPv6: + ipVersion = "6" + } + return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion) } func runDebugLogLevel(cmd *cobra.Command, args []string) error { @@ -171,3 +219,84 @@ func runDebugStart(cmd *cobra.Command, args []string) error { func runDebugStop(cmd *cobra.Command, args []string) error { return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) } + +func runDebugCapture(cmd *cobra.Command, args []string) error { + duration, _ := cmd.Flags().GetDuration("duration") + forcePcap, _ := cmd.Flags().GetBool("pcap") + verbose, _ := cmd.Flags().GetBool("verbose") + ascii, _ := cmd.Flags().GetBool("ascii") + outPath, _ := cmd.Flags().GetString("output") + + // Default to text. Use pcap when --pcap is set or --output is given. + wantText := !forcePcap && outPath == "" + + var filterExpr string + if len(args) > 1 { + filterExpr = strings.Join(args[1:], " ") + } + + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + out, cleanup, err := captureOutputWriter(cmd, outPath) + if err != nil { + return err + } + defer cleanup() + + if wantText { + cmd.PrintErrln("Capturing packets... Press Ctrl+C to stop.") + } else { + cmd.PrintErrln("Capturing packets (pcap)... 
Press Ctrl+C to stop.") + } + + var durationStr string + if duration > 0 { + durationStr = duration.String() + } + + err = getDebugClient(cmd).Capture(ctx, debug.CaptureOptions{ + AccountID: args[0], + Duration: durationStr, + FilterExpr: filterExpr, + Text: wantText, + Verbose: verbose, + ASCII: ascii, + Output: out, + }) + if err != nil { + return err + } + + cmd.PrintErrln("\nCapture finished.") + return nil +} + +// captureOutputWriter returns the writer and cleanup function for capture output. +func captureOutputWriter(cmd *cobra.Command, outPath string) (out *os.File, cleanup func(), err error) { + if outPath != "" { + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() { + if err := f.Close(); err != nil { + cmd.PrintErrf("close output file: %v\n", err) + } + if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 { + if err := os.Rename(tmpPath, outPath); err != nil { + cmd.PrintErrf("rename output file: %v\n", err) + } else { + cmd.PrintErrf("Wrote %s\n", outPath) + } + } else { + os.Remove(tmpPath) + } + }, nil + } + + return os.Stdout, func() { + // no cleanup needed for stdout + }, nil +} diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 055e4510f..3b383f8b4 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat http.SetCookie(w, &http.Cookie{ Name: auth.SessionCookieName, Value: token, + Path: "/", HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 16d09800c..2c93d7912 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -391,6 +391,15 @@ func 
TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite) } +func TestSetSessionCookieHasRootPath(t *testing.T) { + w := httptest.NewRecorder() + setSessionCookie(w, "test-token", time.Hour) + + cookies := w.Result().Cookies() + require.Len(t, cookies, 1) + assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths") +} + func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6..09c25afb2 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -6,10 +6,12 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "strings" "time" + ) // StatusFilters contains filter options for status queries. @@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error } // PingTCP performs a TCP ping through a client. -func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error { +// ipVersion may be "4", "6", or "" for automatic. 
+func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error { params := url.Values{} params.Set("host", host) params.Set("port", fmt.Sprintf("%d", port)) - if timeout != "" { - params.Set("timeout", timeout) + if timeout > 0 { + params.Set("timeout", timeout.String()) + } + if ipVersion != "" { + params.Set("ip_version", ipVersion) } path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode()) @@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, func (c *Client) printPingResult(data map[string]any) { success, _ := data["success"].(bool) + host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"])) if success { - _, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"]) + remote, _ := data["remote"].(string) + if remote != "" && remote != host { + _, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote) + } else { + _, _ = fmt.Fprintf(c.out, "Success: %s\n", host) + } _, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"]) } else { - _, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"]) + _, _ = fmt.Fprintf(c.out, "Failed: %s\n", host) c.printError(data) } } @@ -310,6 +322,76 @@ func (c *Client) printError(data map[string]any) { } } +// CaptureOptions configures a capture request. +type CaptureOptions struct { + AccountID string + Duration string + FilterExpr string + Text bool + Verbose bool + ASCII bool + Output io.Writer +} + +// Capture streams a packet capture from the debug endpoint. The response body +// (pcap or text) is written directly to opts.Output until the server closes the +// connection or the context is cancelled. 
+func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error { + if opts.AccountID == "" { + return fmt.Errorf("account ID is required") + } + if opts.Output == nil { + return fmt.Errorf("output writer is required") + } + + params := url.Values{} + if opts.Duration != "" { + params.Set("duration", opts.Duration) + } + if opts.FilterExpr != "" { + params.Set("filter", opts.FilterExpr) + } + if opts.Text { + params.Set("format", "text") + } + if opts.Verbose { + params.Set("verbose", "true") + } + if opts.ASCII { + params.Set("ascii", "true") + } + + path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID)) + if len(params) > 0 { + path += "?" + params.Encode() + } + + fullURL := c.baseURL + path + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + // Use a separate client without timeout since captures stream for their full duration. + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + _, err = io.Copy(opts.Output, resp.Body) + if err != nil && ctx.Err() != nil { + return nil + } + return err +} + func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error { data, raw, err := c.fetch(ctx, path) if err != nil { diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9..23ca4adbb 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -9,6 +9,7 @@ import ( "fmt" "html/template" "maps" + "net" "net/http" "slices" "strconv" @@ -174,6 +175,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat h.handleClientStart(w, 
r, accountID) case "stop": h.handleClientStop(w, r, accountID) + case "capture": + h.handleCapture(w, r, accountID) default: return false } @@ -525,13 +528,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } } + network := "tcp" + if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" { + network += v + } + ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - address := fmt.Sprintf("%s:%d", host, port) + address := net.JoinHostPort(host, strconv.Itoa(port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { h.writeJSON(w, map[string]interface{}{ "success": false, @@ -541,18 +549,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI }) return } + + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { h.logger.Debugf("close tcp ping connection: %v", err) } latency := time.Since(start) - h.writeJSON(w, map[string]interface{}{ + resp := map[string]interface{}{ "success": true, "host": host, "port": port, + "remote": remote, "latency_ms": latency.Milliseconds(), "latency": formatDuration(latency), - }) + } + h.writeJSON(w, resp) } func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { @@ -632,6 +644,81 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou }) } +const maxCaptureDuration = 30 * time.Minute + +// handleCapture streams a pcap or text packet capture for the given client. +// +// Query params: +// +// duration: capture duration (0 or absent = max, capped at 30m) +// format: "text" for human-readable output (default: pcap) +// filter: BPF-like filter expression (e.g. 
"host 10.0.0.1 and tcp port 443") +func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + http.Error(w, "client not found", http.StatusNotFound) + return + } + + duration := maxCaptureDuration + if durationStr := r.URL.Query().Get("duration"); durationStr != "" { + d, err := time.ParseDuration(durationStr) + if err != nil { + http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest) + return + } + if d < 0 { + http.Error(w, "duration must not be negative", http.StatusBadRequest) + return + } + if d > 0 { + duration = min(d, maxCaptureDuration) + } + } + + filter := r.URL.Query().Get("filter") + wantText := r.URL.Query().Get("format") == "text" + verbose := r.URL.Query().Get("verbose") == "true" + ascii := r.URL.Query().Get("ascii") == "true" + + opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii} + if wantText { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + opts.TextOutput = w + } else { + w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap") + w.Header().Set("Content-Disposition", + fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID)) + opts.Output = w + } + + cs, err := client.StartCapture(opts) + if err != nil { + http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable) + return + } + defer cs.Stop() + + // Flush headers after setup succeeds so errors above can still set status codes. 
+ if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-r.Context().Done(): + case <-timer.C: + } + + cs.Stop() + + stats := cs.Stats() + h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped", + accountID, stats.Packets, stats.Bytes, stats.Dropped) +} + func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) { if !wantJSON { http.Redirect(w, r, "/debug", http.StatusSeeOther) diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 4b1ecf922..99bbdad0c 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { +func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { + return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil +} + +func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error { return nil } -func (m *testProxyManager) Disconnect(_ context.Context, _ string) error { - return nil -} - -func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { +func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error { return nil } @@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings from the snapshot - server sends each mapping individually mappingsByID := make(map[string]*proto.ProxyMapping) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) for _, m := range msg.GetMapping() { 
mappingsByID[m.GetId()] = m } + if msg.GetInitialSyncComplete() { + break + } } // Should receive 2 mappings total @@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually mappings := make([]*proto.ProxyMapping, 0) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } } // Should receive the 2 mappings matching the cluster @@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) clusterAddress := "test.proxy.io" proxyID := "test-proxy-reconnect" - // Helper to receive all mappings from a stream - receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping { + receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { var mappings []*proto.ProxyMapping - for i := 0; i < count; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) 
+ if msg.GetInitialSyncComplete() { + break + } } return mappings } @@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - firstMappings := receiveMappings(stream1, 2) + firstMappings := receiveMappings(stream1) cancel1() time.Sleep(100 * time.Millisecond) @@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - secondMappings := receiveMappings(stream2, 2) + secondMappings := receiveMappings(stream2) // Should receive the same mappings assert.Equal(t, len(firstMappings), len(secondMappings), @@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T } } - // Helper to receive and apply all mappings receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) { - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) applyMappings(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } } @@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually count := 0 - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) count += len(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } mu.Lock() @@ -656,3 +666,78 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID) } } + +// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that +// when a proxy reconnects before the old stream's cleanup runs, the new +// connection is NOT removed by the stale defer. 
+func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + clusterAddress := "test.proxy.io" + proxyID := "test-proxy-race" + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithCancel(context.Background()) + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: clusterAddress, + }) + require.NoError(t, err) + + for { + msg, err := stream1.Recv() + require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } + } + + require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, + "proxy should be registered after first connection") + + ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: clusterAddress, + }) + require.NoError(t, err) + + for { + msg, err := stream2.Recv() + require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } + } + + cancel1() + + time.Sleep(200 * time.Millisecond) + + assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, + "proxy should still be registered after old connection cleanup — old defer must not remove new connection") + + setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "rp-1", + AccountId: "test-account-1", + Domain: "app1.test.proxy.io", + }}, + }) + + msg, err := stream2.Recv() + require.NoError(t, err, "new stream should still receive updates") + require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping") + assert.Equal(t, "rp-1", 
msg.GetMapping()[0].GetId()) +} diff --git a/proxy/server.go b/proxy/server.go index fbd0d058e..6980e1df1 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr operation := func() error { s.Logger.Debug("connecting to management mapping stream") + initialSyncDone = false + if s.healthChecker != nil { s.healthChecker.SetManagementConnected(false) } @@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr return ctx.Err() } + var snapshotIDs map[types.ServiceID]struct{} + if !*initialSyncDone { + snapshotIDs = make(map[types.ServiceID]struct{}) + } + for { // Check for context completion to gracefully shutdown. select { @@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr s.processMappings(ctx, msg.GetMapping()) s.Logger.Debug("Processing mapping update completed") - if !*initialSyncDone && msg.GetInitialSyncComplete() { - if s.healthChecker != nil { - s.healthChecker.SetInitialSyncComplete() + if !*initialSyncDone { + for _, m := range msg.GetMapping() { + snapshotIDs[types.ServiceID(m.GetId())] = struct{}{} + } + if msg.GetInitialSyncComplete() { + s.reconcileSnapshot(ctx, snapshotIDs) + snapshotIDs = nil + if s.healthChecker != nil { + s.healthChecker.SetInitialSyncComplete() + } + *initialSyncDone = true + s.Logger.Info("Initial mapping sync complete") } - *initialSyncDone = true - s.Logger.Info("Initial mapping sync complete") } } } } +// reconcileSnapshot removes local mappings that are absent from the snapshot. +// This ensures services deleted while the proxy was disconnected get cleaned up. 
+func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) { + s.portMu.RLock() + var stale []*proto.ProxyMapping + for svcID, mapping := range s.lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, mapping) + } + } + s.portMu.RUnlock() + + for _, mapping := range stale { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "domain": mapping.GetDomain(), + }).Info("Removing stale mapping absent from snapshot") + s.removeMapping(ctx, mapping) + } +} + func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { for _, mapping := range mappings { s.Logger.WithFields(log.Fields{ diff --git a/proxy/snapshot_reconcile_test.go b/proxy/snapshot_reconcile_test.go new file mode 100644 index 000000000..042d8df77 --- /dev/null +++ b/proxy/snapshot_reconcile_test.go @@ -0,0 +1,227 @@ +package proxy + +import ( + "context" + "io" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/health" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot +// so we can verify it without triggering removeMapping (which requires full +// server wiring). This keeps the test focused on the detection algorithm. +func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID { + var stale []types.ServiceID + for svcID := range lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, svcID) + } + } + return stale +} + +// TestStaleDetection_PartialOverlap verifies that only services absent from +// the snapshot are flagged as stale. 
+func TestStaleDetection_PartialOverlap(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + "svc-stale-a": {Id: "svc-stale-a"}, + "svc-stale-b": {Id: "svc-stale-b"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + "svc-3": {}, // new service, not in local + } + + stale := collectStaleIDs(local, snapshot) + assert.Len(t, stale, 2) + staleSet := make(map[types.ServiceID]struct{}) + for _, id := range stale { + staleSet[id] = struct{}{} + } + assert.Contains(t, staleSet, types.ServiceID("svc-stale-a")) + assert.Contains(t, staleSet, types.ServiceID("svc-stale-b")) +} + +// TestStaleDetection_AllStale verifies an empty snapshot flags everything. +func TestStaleDetection_AllStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + stale := collectStaleIDs(local, map[types.ServiceID]struct{}{}) + assert.Len(t, stale, 2) +} + +// TestStaleDetection_NoneStale verifies full overlap produces no stale entries. +func TestStaleDetection_NoneStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + stale := collectStaleIDs(local, snapshot) + assert.Empty(t, stale) +} + +// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty. +func TestStaleDetection_EmptyLocal(t *testing.T) { + stale := collectStaleIDs( + map[types.ServiceID]*proto.ProxyMapping{}, + map[types.ServiceID]struct{}{"svc-1": {}}, + ) + assert.Empty(t, stale) +} + +// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all +// local mappings are present in the snapshot (removeMapping is never called). 
+func TestReconcileSnapshot_NoStale(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"} + + snapshotIDs := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + // This should not panic — no stale entries means removeMapping is never called. + s.reconcileSnapshot(context.Background(), snapshotIDs) + + assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot") +} + +// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with +// no local mappings. +func TestReconcileSnapshot_EmptyLocal(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}}) + assert.Empty(t, s.lastMappings) +} + +// --- handleMappingStream tests for batched snapshot ID accumulation --- + +// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is +// marked done only after the final InitialSyncComplete message, even when +// the snapshot arrives in multiple batches. 
+func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) { + checker := health.NewChecker(nil, nil) + s := &Server{ + Logger: log.StandardLogger(), + healthChecker: checker, + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // batch 1: no sync-complete + {}, // batch 2: no sync-complete + {InitialSyncComplete: true}, // batch 3: sync done + }, + } + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.True(t, syncDone, "sync should be marked done after final batch") +} + +// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages +// arriving after InitialSyncComplete do not trigger a second reconciliation. +func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + // Simulate state left over from a previous sync. + s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"} + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // post-sync empty message — must not reconcile + }, + } + + syncDone := true // sync already completed in a previous stream + err := s.handleMappingStream(context.Background(), stream, &syncDone) + require.NoError(t, err) + + assert.Len(t, s.lastMappings, 2, + "post-sync messages must not trigger reconciliation — all entries should survive") +} + +// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the +// stream closes before sync completes, no reconciliation occurs. 
+func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + stream := &mockMappingStream{} // no messages → immediate EOF + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.False(t, syncDone, "sync should not be marked done on immediate EOF") + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync never completed") +} + +// mockErrRecvStream returns an error on the second Recv to verify +// handleMappingStream returns without completing sync. +type mockErrRecvStream struct { + mockMappingStream + calls int +} + +func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) { + m.calls++ + if m.calls == 1 { + return &proto.GetMappingUpdateResponse{}, nil + } + return nil, io.ErrUnexpectedEOF +} + +func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + syncDone := false + err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone) + assert.Error(t, err) + assert.False(t, syncDone) + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error") +} diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index 4dfea6da1..6b1131f1e 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -337,7 +337,7 @@ func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration { func 
getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) { t.Helper() // Dial TURN Server - addrStr := fmt.Sprintf("%s:%d", address, 443) + addrStr := net.JoinHostPort(address, "443") fac := logging.NewDefaultLoggerFactory() //fac.DefaultLogLevel = logging.LogLevelTrace diff --git a/relay/testec2/turn_allocator.go b/relay/testec2/turn_allocator.go index fd86208df..440f6222a 100644 --- a/relay/testec2/turn_allocator.go +++ b/relay/testec2/turn_allocator.go @@ -52,7 +52,7 @@ func AllocateTurnClient(serverAddr string) *TurnConn { func getTurnClient(address string, conn net.Conn) (*turn.Client, error) { // Dial TURN Server - addrStr := fmt.Sprintf("%s:%d", address, 443) + addrStr := net.JoinHostPort(address, "443") fac := logging.NewDefaultLoggerFactory() //fac.DefaultLogLevel = logging.LogLevelTrace diff --git a/route/route.go b/route/route.go index c724e7c7d..97b9721f6 100644 --- a/route/route.go +++ b/route/route.go @@ -20,6 +20,9 @@ const ( MaxMetric = 9999 // MaxNetIDChar Max Network Identifier MaxNetIDChar = 40 + + // V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart. + V6ExitSuffix = "-v6" ) const ( @@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { return IPv4Network, masked, nil } + +var ( + v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0) +) + +// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0). +func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default } + +// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0). +func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default } + +// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit +// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap. +// It modifies and returns the input slice. 
+func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID { + for _, id := range ids { + rt, ok := routesMap[id] + if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) { + continue + } + v6ID := NetID(string(id) + V6ExitSuffix) + if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) { + if !slices.Contains(ids, v6ID) { + ids = append(ids, v6ID) + } + } + } + return ids +} + +// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs +// that should be hidden from the UI because they are paired with a v4 exit node. +// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base +// name (without "-v6") exists with route 0.0.0.0/0. +func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} { + merged := make(map[NetID]struct{}) + for id, rt := range routesMap { + if len(rt) == 0 { + continue + } + name := string(id) + if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) { + continue + } + baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix)) + if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) { + merged[id] = struct{}{} + } + } + return merged +} + +// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set. 
+func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool { + _, ok := v6Merged[NetID(string(id)+"-v6")] + return ok +} diff --git a/route/route_test.go b/route/route_test.go new file mode 100644 index 000000000..dab707ed3 --- /dev/null +++ b/route/route_test.go @@ -0,0 +1,108 @@ +package route + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExpandV6ExitPairs(t *testing.T) { + v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")} + v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")} + regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")} + + tests := []struct { + name string + ids []NetID + routesMap map[NetID][]*Route + expected []NetID + }{ + { + name: "v4 exit node with matching v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "v4 exit node without v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + }, + expected: []NetID{"exit-node"}, + }, + { + name: "regular route is not expanded", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {v6ExitRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "v6 already included is not duplicated", + ids: []NetID{"exit-node", "exit-node-v6"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "multiple exit nodes expanded independently", + ids: []NetID{"exit-a", "exit-b"}, + routesMap: map[NetID][]*Route{ + "exit-a": {v4ExitRoute}, + "exit-a-v6": {v6ExitRoute}, + "exit-b": {v4ExitRoute}, + "exit-b-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"}, + }, + { + name: "v6 suffix but not exit node network", + ids: 
[]NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {regularRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "user-chosen name for exit node with v6 pair", + ids: []NetID{"my-exit"}, + routesMap: map[NetID][]*Route{ + "my-exit": {v4ExitRoute}, + "my-exit-v6": {v6ExitRoute}, + }, + expected: []NetID{"my-exit", "my-exit-v6"}, + }, + { + name: "real-world management-generated IDs", + ids: []NetID{"0.0.0.0/0"}, + routesMap: map[NetID][]*Route{ + "0.0.0.0/0": {v4ExitRoute}, + "0.0.0.0/0-v6": {v6ExitRoute}, + }, + expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"}, + }, + { + name: "empty input", + ids: []NetID{}, + routesMap: map[NetID][]*Route{}, + expected: []NetID{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExpandV6ExitPairs(tt.ids, tt.routesMap) + assert.ElementsMatch(t, tt.expected, result) + }) + } +} diff --git a/shared/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go index 5806d1f4d..d113f53d5 100644 --- a/shared/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -146,7 +146,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string userJWTGroups := make([]string, 0) if claim, ok := claims[claimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { + switch claimGroups := claim.(type) { + case string: + // Some IdPs emit a single group claim as a string instead of an array. 
+ userJWTGroups = append(userJWTGroups, claimGroups) + case []any: for _, g := range claimGroups { if group, ok := g.(string); ok { userJWTGroups = append(userJWTGroups, group) @@ -154,9 +158,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) } } + default: + log.Debugf("JWT claim %q is not a string or string array (type: %T): %v", claimName, claim, claim) } } else { - log.Debugf("JWT claim %q is not a string array", claimName) + log.Debugf("JWT claim %q is missing", claimName) } return userJWTGroups diff --git a/shared/auth/jwt/extractor_test.go b/shared/auth/jwt/extractor_test.go index 45529770d..4f8fe0007 100644 --- a/shared/auth/jwt/extractor_test.go +++ b/shared/auth/jwt/extractor_test.go @@ -249,6 +249,15 @@ func TestClaimsExtractor_ToGroups(t *testing.T) { groupClaimName: "groups", expectedGroups: []string{}, }, + { + name: "extracts single group string from claim", + claims: jwt.MapClaims{ + "sub": "user-123", + "groups": "admin", + }, + groupClaimName: "groups", + expectedGroups: []string{"admin"}, + }, { name: "handles custom claim name", claims: jwt.MapClaims{ diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d9a1a7d65..a8e8172dc 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -138,7 +138,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/grpc.go 
b/shared/management/client/grpc.go index a01e51abc..58895b7c2 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -30,6 +30,8 @@ import ( const ConnectTimeout = 10 * time.Second +const healthCheckTimeout = 5 * time.Second + const ( // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) // for the management client connection. Value is in bytes. @@ -244,29 +246,23 @@ func (c *GrpcClient) handleJobStream( for { jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey) if err != nil { + if ctx.Err() != nil { + log.Debugf("job stream context has been canceled, this usually indicates shutdown") + return nil + } if s, ok := gstatus.FromError(err); ok { switch s.Code() { case codes.PermissionDenied: c.notifyDisconnected(err) return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer - case codes.Canceled: - log.Debugf("management connection context has been canceled, this usually indicates shutdown") - return err case codes.Unimplemented: log.Warn("Job feature is not supported by the current management server version. " + "Please update the management service to use this feature.") return nil - default: - c.notifyDisconnected(err) - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err } - } else { - // non-gRPC error - c.notifyDisconnected(err) - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err } + log.Warnf("job stream disconnected, will retry silently. Reason: %v", err) + return err } if jobReq == nil || len(jobReq.ID) == 0 { @@ -381,22 +377,15 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes. 
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler) if err != nil { c.notifyDisconnected(err) - if s, ok := gstatus.FromError(err); ok { - switch s.Code() { - case codes.PermissionDenied: - return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer - case codes.Canceled: - log.Debugf("management connection context has been canceled, this usually indicates shutdown") - return nil - default: - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err - } - } else { - // non-gRPC error - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err + if ctx.Err() != nil { + log.Debugf("management connection context has been canceled, this usually indicates shutdown") + return nil } + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { + return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer + } + log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) + return err } return nil @@ -532,7 +521,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() _, err := c.realClient.GetServerKey(ctx, &proto.Empty{}) @@ -948,8 +937,22 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { DisableFirewall: info.DisableFirewall, BlockLANAccess: info.BlockLANAccess, BlockInbound: info.BlockInbound, + DisableIPv6: info.DisableIPv6, LazyConnectionEnabled: info.LazyConnectionEnabled, }, + + Capabilities: peerCapabilities(*info), } } + +// peerCapabilities returns the capabilities this client supports. 
+func peerCapabilities(info system.Info) []proto.PeerCapability { + caps := []proto.PeerCapability{ + proto.PeerCapability_PeerCapabilitySourcePrefixes, + } + if !info.DisableIPv6 { + caps = append(caps, proto.PeerCapability_PeerCapabilityIPv6Overlay) + } + return caps +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 0b855db67..8e6ee54cc 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -341,7 +341,11 @@ components: description: Allows to define a custom network range for the account in CIDR format type: string format: cidr - example: 100.64.0.0/16 + network_range_v6: + description: Allows to define a custom IPv6 network range for the account in CIDR format. + type: string + format: cidr + example: fd00:1234:5678::/64 peer_expose_enabled: description: Enables or disables peer expose. If enabled, peers can expose local services through the reverse proxy using the CLI. type: boolean @@ -377,6 +381,12 @@ components: type: boolean readOnly: true example: false + ipv6_enabled_groups: + description: List of group IDs whose peers receive IPv6 overlay addresses. Peers not in any of these groups will not be allocated an IPv6 address. New accounts default to the All group. + type: array + items: + type: string + example: ["ch8i4ug6lnn4g9hqv7m0"] required: - peer_login_expiration_enabled - peer_login_expiration @@ -776,6 +786,11 @@ components: type: string format: ipv4 example: 100.64.0.15 + ipv6: + description: Peer's IPv6 overlay address. Omitted if IPv6 is not enabled for the account. 
+ type: string + format: ipv6 + example: "fd00:4e42:ab12::1" required: - name - ssh_enabled @@ -795,6 +810,11 @@ components: description: Peer's IP address type: string example: 10.64.0.1 + ipv6: + description: Peer's IPv6 overlay address + type: string + format: ipv6 + example: "fd00:4e42:ab12::1" connection_ip: description: Peer's public connection IP address type: string @@ -1013,6 +1033,10 @@ components: description: Peer's IP address type: string example: 10.64.0.1 + ipv6: + description: Peer's IPv6 overlay address + type: string + example: "fd00:4e42:ab12::1" dns_label: description: Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud type: string @@ -1687,15 +1711,18 @@ components: - locations - action PeerNetworkRangeCheck: - description: Posture check for allow or deny access based on peer local network addresses + description: | + Posture check for allow or deny access based on the peer's IP addresses. A range matches when it + contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, + so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. 
type: object properties: ranges: - description: List of peer network ranges in CIDR notation + description: List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP type: array items: type: string - example: [ "192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56" ] + example: [ "192.168.1.0/24", "10.0.0.0/8", "1.0.0.0/24", "2.2.2.2/32", "2001:db8:1234:1a00::/56" ] action: description: Action to take upon policy match type: string @@ -2917,6 +2944,7 @@ components: - okta - pocketid - microsoft + - adfs example: oidc IdentityProvider: type: object @@ -3425,6 +3453,17 @@ components: description: Display name for the admin user (defaults to email if not provided) type: string example: Admin User + create_pat: + description: If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled. + type: boolean + example: true + pat_expire_in: + description: Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted. + type: integer + minimum: 1 + maximum: 365 + default: 1 + example: 30 required: - email - password @@ -3441,6 +3480,12 @@ components: description: Email address of the created user type: string example: admin@example.com + personal_access_token: + description: Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server. + type: string + format: password + readOnly: true + example: nbp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx required: - user_id - email @@ -4979,7 +5024,10 @@ paths: /api/setup: post: summary: Setup Instance - description: Creates the initial admin user for the instance. 
This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled). + description: | + Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled). + + When the management server is started with `NB_SETUP_PAT_ENABLED=true` and the request includes `create_pat: true`, the endpoint also provisions the NetBird account for the new owner user and returns the plain text Personal Access Token in `personal_access_token`. The optional `pat_expire_in` value applies only when `create_pat` is true and defaults to 1 day when omitted. If a post-user step fails, setup-created resources are rolled back when safe; if account cleanup fails, the owner user is left in place to avoid leaving an account without its admin user. tags: [ Instance ] security: [ ] requestBody: @@ -4992,6 +5040,12 @@ paths: responses: '200': description: Setup completed successfully + headers: + Cache-Control: + description: Always set to no-store because the response may contain a one-time plain text Personal Access Token. + schema: + type: string + example: no-store content: application/json: schema: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 0317b8183..f8ea07be7 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -518,6 +518,7 @@ const ( IdentityProviderTypeOkta IdentityProviderType = "okta" IdentityProviderTypePocketid IdentityProviderType = "pocketid" IdentityProviderTypeZitadel IdentityProviderType = "zitadel" + IdentityProviderTypeAdfs IdentityProviderType = "adfs" ) // Valid indicates whether the value is a known member of the IdentityProviderType enum. 
@@ -537,6 +538,8 @@ func (e IdentityProviderType) Valid() bool { return true case IdentityProviderTypeZitadel: return true + case IdentityProviderTypeAdfs: + return true default: return false } @@ -1378,6 +1381,9 @@ type AccessiblePeer struct { // Ip Peer's IP address Ip string `json:"ip"` + // Ipv6 Peer's IPv6 overlay address + Ipv6 *string `json:"ipv6,omitempty"` + // LastSeen Last time peer connected to Netbird's management service LastSeen time.Time `json:"last_seen"` @@ -1462,6 +1468,9 @@ type AccountSettings struct { // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` + // Ipv6EnabledGroups List of group IDs whose peers receive IPv6 overlay addresses. Peers not in any of these groups will not be allocated an IPv6 address. New accounts default to the All group. + Ipv6EnabledGroups *[]string `json:"ipv6_enabled_groups,omitempty"` + // JwtAllowGroups List of groups to which users are allowed access JwtAllowGroups *[]string `json:"jwt_allow_groups,omitempty"` @@ -1480,6 +1489,9 @@ type AccountSettings struct { // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"` + // NetworkRangeV6 Allows to define a custom IPv6 network range for the account in CIDR format. + NetworkRangeV6 *string `json:"network_range_v6,omitempty"` + // PeerExposeEnabled Enables or disables peer expose. If enabled, peers can expose local services through the reverse proxy using the CLI. 
PeerExposeEnabled bool `json:"peer_expose_enabled"` @@ -1623,7 +1635,7 @@ type Checks struct { // OsVersionCheck Posture check for the version of operating system OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"` - // PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses + // PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"` // ProcessCheck Posture Check for binaries exist and are running in the peer’s system @@ -3138,6 +3150,9 @@ type Peer struct { // Ip Peer's IP address Ip string `json:"ip"` + // Ipv6 Peer's IPv6 overlay address + Ipv6 *string `json:"ipv6,omitempty"` + // KernelVersion Peer's operating system kernel version KernelVersion string `json:"kernel_version"` @@ -3229,6 +3244,9 @@ type PeerBatch struct { // Ip Peer's IP address Ip string `json:"ip"` + // Ipv6 Peer's IPv6 overlay address + Ipv6 *string `json:"ipv6,omitempty"` + // KernelVersion Peer's operating system kernel version KernelVersion string `json:"kernel_version"` @@ -3309,12 +3327,12 @@ type PeerMinimum struct { Name string `json:"name"` } -// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses +// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. 
type PeerNetworkRangeCheck struct { // Action Action to take upon policy match Action PeerNetworkRangeCheckAction `json:"action"` - // Ranges List of peer network ranges in CIDR notation + // Ranges List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP Ranges []string `json:"ranges"` } @@ -3328,7 +3346,10 @@ type PeerRequest struct { InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` // Ip Peer's IP address - Ip *string `json:"ip,omitempty"` + Ip *string `json:"ip,omitempty"` + + // Ipv6 Peer's IPv6 overlay address. Omitted if IPv6 is not enabled for the account. + Ipv6 *string `json:"ipv6,omitempty"` LoginExpirationEnabled bool `json:"login_expiration_enabled"` Name string `json:"name"` SshEnabled bool `json:"ssh_enabled"` @@ -4294,6 +4315,9 @@ type SetupKeyRequest struct { // SetupRequest Request to set up the initial admin user type SetupRequest struct { + // CreatePat If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled. + CreatePat *bool `json:"create_pat,omitempty"` + // Email Email address for the admin user Email string `json:"email"` @@ -4302,6 +4326,9 @@ type SetupRequest struct { // Password Password for the admin user (minimum 8 characters) Password string `json:"password"` + + // PatExpireIn Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted. + PatExpireIn *int `json:"pat_expire_in,omitempty"` } // SetupResponse Response after successful instance setup @@ -4309,6 +4336,9 @@ type SetupResponse struct { // Email Email address of the created user Email string `json:"email"` + // PersonalAccessToken Plain text Personal Access Token created during setup. 
Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server. + PersonalAccessToken *string `json:"personal_access_token,omitempty"` + // UserId The ID of the created user UserId string `json:"user_id"` } diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 604f9c793..13f4fbc8d 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -71,6 +71,59 @@ func (JobStatus) EnumDescriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{0} } +// PeerCapability represents a feature the client binary supports. +// Reported in PeerSystemMeta.capabilities on every login/sync. +type PeerCapability int32 + +const ( + PeerCapability_PeerCapabilityUnknown PeerCapability = 0 + // Client reads SourcePrefixes instead of the deprecated PeerIP string. + PeerCapability_PeerCapabilitySourcePrefixes PeerCapability = 1 + // Client handles IPv6 overlay addresses and firewall rules. + PeerCapability_PeerCapabilityIPv6Overlay PeerCapability = 2 +) + +// Enum value maps for PeerCapability. 
+var ( + PeerCapability_name = map[int32]string{ + 0: "PeerCapabilityUnknown", + 1: "PeerCapabilitySourcePrefixes", + 2: "PeerCapabilityIPv6Overlay", + } + PeerCapability_value = map[string]int32{ + "PeerCapabilityUnknown": 0, + "PeerCapabilitySourcePrefixes": 1, + "PeerCapabilityIPv6Overlay": 2, + } +) + +func (x PeerCapability) Enum() *PeerCapability { + p := new(PeerCapability) + *p = x + return p +} + +func (x PeerCapability) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (PeerCapability) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[1].Descriptor() +} + +func (PeerCapability) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[1] +} + +func (x PeerCapability) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use PeerCapability.Descriptor instead. +func (PeerCapability) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{1} +} + type RuleProtocol int32 const ( @@ -113,11 +166,11 @@ func (x RuleProtocol) String() string { } func (RuleProtocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[1].Descriptor() + return file_management_proto_enumTypes[2].Descriptor() } func (RuleProtocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[1] + return &file_management_proto_enumTypes[2] } func (x RuleProtocol) Number() protoreflect.EnumNumber { @@ -126,7 +179,7 @@ func (x RuleProtocol) Number() protoreflect.EnumNumber { // Deprecated: Use RuleProtocol.Descriptor instead. 
func (RuleProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{1} + return file_management_proto_rawDescGZIP(), []int{2} } type RuleDirection int32 @@ -159,11 +212,11 @@ func (x RuleDirection) String() string { } func (RuleDirection) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[2].Descriptor() + return file_management_proto_enumTypes[3].Descriptor() } func (RuleDirection) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[2] + return &file_management_proto_enumTypes[3] } func (x RuleDirection) Number() protoreflect.EnumNumber { @@ -172,7 +225,7 @@ func (x RuleDirection) Number() protoreflect.EnumNumber { // Deprecated: Use RuleDirection.Descriptor instead. func (RuleDirection) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{2} + return file_management_proto_rawDescGZIP(), []int{3} } type RuleAction int32 @@ -205,11 +258,11 @@ func (x RuleAction) String() string { } func (RuleAction) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[3].Descriptor() + return file_management_proto_enumTypes[4].Descriptor() } func (RuleAction) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[3] + return &file_management_proto_enumTypes[4] } func (x RuleAction) Number() protoreflect.EnumNumber { @@ -218,7 +271,7 @@ func (x RuleAction) Number() protoreflect.EnumNumber { // Deprecated: Use RuleAction.Descriptor instead. 
func (RuleAction) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{3} + return file_management_proto_rawDescGZIP(), []int{4} } type ExposeProtocol int32 @@ -260,11 +313,11 @@ func (x ExposeProtocol) String() string { } func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[4].Descriptor() + return file_management_proto_enumTypes[5].Descriptor() } func (ExposeProtocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[4] + return &file_management_proto_enumTypes[5] } func (x ExposeProtocol) Number() protoreflect.EnumNumber { @@ -273,7 +326,7 @@ func (x ExposeProtocol) Number() protoreflect.EnumNumber { // Deprecated: Use ExposeProtocol.Descriptor instead. func (ExposeProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{4} + return file_management_proto_rawDescGZIP(), []int{5} } type HostConfig_Protocol int32 @@ -315,11 +368,11 @@ func (x HostConfig_Protocol) String() string { } func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[5].Descriptor() + return file_management_proto_enumTypes[6].Descriptor() } func (HostConfig_Protocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[5] + return &file_management_proto_enumTypes[6] } func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { @@ -358,11 +411,11 @@ func (x DeviceAuthorizationFlowProvider) String() string { } func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[6].Descriptor() + return file_management_proto_enumTypes[7].Descriptor() } func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[6] + return &file_management_proto_enumTypes[7] } func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { @@ -1201,6 +1254,7 @@ type Flags 
struct { EnableSSHLocalPortForwarding bool `protobuf:"varint,13,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` EnableSSHRemotePortForwarding bool `protobuf:"varint,14,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth bool `protobuf:"varint,15,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` + DisableIPv6 bool `protobuf:"varint,16,opt,name=disableIPv6,proto3" json:"disableIPv6,omitempty"` } func (x *Flags) Reset() { @@ -1340,6 +1394,13 @@ func (x *Flags) GetDisableSSHAuth() bool { return false } +func (x *Flags) GetDisableIPv6() bool { + if x != nil { + return x.DisableIPv6 + } + return false +} + // PeerSystemMeta is machine meta data like OS and version. type PeerSystemMeta struct { state protoimpl.MessageState @@ -1363,6 +1424,7 @@ type PeerSystemMeta struct { Environment *Environment `protobuf:"bytes,15,opt,name=environment,proto3" json:"environment,omitempty"` Files []*File `protobuf:"bytes,16,rep,name=files,proto3" json:"files,omitempty"` Flags *Flags `protobuf:"bytes,17,opt,name=flags,proto3" json:"flags,omitempty"` + Capabilities []PeerCapability `protobuf:"varint,18,rep,packed,name=capabilities,proto3,enum=management.PeerCapability" json:"capabilities,omitempty"` } func (x *PeerSystemMeta) Reset() { @@ -1516,6 +1578,13 @@ func (x *PeerSystemMeta) GetFlags() *Flags { return nil } +func (x *PeerSystemMeta) GetCapabilities() []PeerCapability { + if x != nil { + return x.Capabilities + } + return nil +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -2163,6 +2232,8 @@ type PeerConfig struct { Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"` // Auto-update config AutoUpdate *AutoUpdateSettings `protobuf:"bytes,8,opt,name=autoUpdate,proto3" json:"autoUpdate,omitempty"` + // IPv6 overlay address as compact bytes: 16 bytes IP + 1 byte prefix length. 
+ AddressV6 []byte `protobuf:"bytes,9,opt,name=address_v6,json=addressV6,proto3" json:"address_v6,omitempty"` } func (x *PeerConfig) Reset() { @@ -2253,6 +2324,13 @@ func (x *PeerConfig) GetAutoUpdate() *AutoUpdateSettings { return nil } +func (x *PeerConfig) GetAddressV6() []byte { + if x != nil { + return x.AddressV6 + } + return nil +} + type AutoUpdateSettings struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3562,6 +3640,9 @@ type FirewallRule struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // Use sourcePrefixes instead. + // + // Deprecated: Do not use. PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` Direction RuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.RuleDirection" json:"Direction,omitempty"` Action RuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.RuleAction" json:"Action,omitempty"` @@ -3570,6 +3651,11 @@ type FirewallRule struct { PortInfo *PortInfo `protobuf:"bytes,6,opt,name=PortInfo,proto3" json:"PortInfo,omitempty"` // PolicyID is the ID of the policy that this rule belongs to PolicyID []byte `protobuf:"bytes,7,opt,name=PolicyID,proto3" json:"PolicyID,omitempty"` + // CustomProtocol is a custom protocol ID when Protocol is CUSTOM. + CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` + // Compact source IP prefixes for this rule, supersedes PeerIP. + // Each entry is 5 bytes (v4) or 17 bytes (v6): [IP bytes][1 byte prefix_len]. + SourcePrefixes [][]byte `protobuf:"bytes,9,rep,name=sourcePrefixes,proto3" json:"sourcePrefixes,omitempty"` } func (x *FirewallRule) Reset() { @@ -3604,6 +3690,7 @@ func (*FirewallRule) Descriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{41} } +// Deprecated: Do not use. 
func (x *FirewallRule) GetPeerIP() string { if x != nil { return x.PeerIP @@ -3653,6 +3740,20 @@ func (x *FirewallRule) GetPolicyID() []byte { return nil } +func (x *FirewallRule) GetCustomProtocol() uint32 { + if x != nil { + return x.CustomProtocol + } + return 0 +} + +func (x *FirewallRule) GetSourcePrefixes() [][]byte { + if x != nil { + return x.SourcePrefixes + } + return nil +} + type NetworkAddress struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -4542,7 +4643,7 @@ var file_management_proto_rawDesc = []byte{ 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, 0x12, 0x2a, 0x0a, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, - 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xbf, 0x05, 0x0a, 0x05, 0x46, 0x6c, 0x61, 0x67, 0x73, + 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xe1, 0x05, 0x0a, 0x05, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, @@ -4586,551 +4687,571 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x26, 0x0a, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, - 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x22, 0xf2, 0x04, 0x0a, 0x0e, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x68, - 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, - 0x6f, 0x73, 0x74, 0x6e, 0x61, 
0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x12, 0x16, 0x0a, 0x06, 0x6b, - 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6b, 0x65, 0x72, - 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, - 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, - 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, - 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, - 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, - 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, - 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, - 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 
0x73, 0x53, 0x65, 0x72, 0x69, - 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, - 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, - 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, - 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, - 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, - 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, - 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, - 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, - 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x26, 0x0a, 0x05, - 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x66, - 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, 0x11, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x22, 0xb4, 0x01, - 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 
0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, - 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, - 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, - 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, - 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, - 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, - 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, - 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, - 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 
0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, - 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, - 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, - 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, - 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, - 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, - 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, - 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, - 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 
0x0a, 0x04, 0x44, - 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, - 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, - 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, - 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, - 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, - 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, - 0x20, 0x01, 0x28, 
0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, - 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, - 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, - 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xa3, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, - 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, - 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, - 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6b, 0x65, - 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x6d, 0x61, - 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x12, 0x1c, 0x0a, 0x09, - 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, - 0x09, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, - 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 
0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, - 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, - 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0xd3, 0x02, 0x0a, 0x0a, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, - 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, - 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, - 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 
0x0a, - 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, - 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, - 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, - 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, - 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, - 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x73, 0x12, 0x2e, 0x0a, 
0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, - 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, - 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, - 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x69, 0x73, 0x61, + 0x62, 0x6c, 0x65, 0x49, 0x50, 0x76, 0x36, 0x18, 0x10, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x64, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x49, 0x50, 0x76, 0x36, 0x22, 0xb2, 0x05, 0x0a, 0x0e, 0x50, + 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1a, 0x0a, + 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x6f, 0x4f, + 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x12, 0x16, 0x0a, + 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6b, + 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 
0x0a, 0x08, 0x70, 0x6c, 0x61, + 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x6c, 0x61, + 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, + 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, + 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, + 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x6b, + 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x09, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x0a, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, + 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, + 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, + 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, + 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, + 0x6f, 
0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, + 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, 0x0e, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, + 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, + 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, + 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x26, + 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, + 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, + 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x12, + 0x3e, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, + 0x12, 0x20, 0x03, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x79, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, + 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 
0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, + 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68, + 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, + 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x38, 0x0a, + 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, + 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, 0x0a, 0x0d, 0x4e, + 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, + 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, 0x05, 0x74, 0x75, + 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, 
0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, + 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, 0x75, 0x72, 0x6e, + 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, + 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, + 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, + 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x6f, 0x77, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, 0x98, 0x01, 0x0a, + 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, + 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, + 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, + 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 0x0a, + 0x04, 0x44, 0x54, 
0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, + 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, + 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 
0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, + 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, + 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xa3, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, + 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, + 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, + 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, + 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x12, 0x1c, + 0x0a, 0x09, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x09, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0x7d, 0x0a, 0x13, + 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 
0x52, + 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, + 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, + 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0xf2, 0x02, 0x0a, 0x0a, + 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, + 0x48, 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, + 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, + 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, + 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, + 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x10, 0x0a, 0x03, 0x6d, 0x74, 
0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, + 0x75, 0x12, 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, + 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, + 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x76, 0x36, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x56, 0x36, + 0x22, 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, + 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 
0x50, 0x65, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, - 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 
0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, + 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, + 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, + 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, - 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 
0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, - 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, - 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, - 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, 0x82, - 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, 0x73, - 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, 0x0f, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, - 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a, 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, - 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, - 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, - 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, - 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, - 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 
0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, - 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, - 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, - 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, 0x64, - 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, 0x65, - 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, - 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, - 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, - 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, - 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, - 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, - 0x0a, 0x09, 0x73, 0x73, 0x68, 
0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, - 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, - 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, - 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, - 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, - 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, - 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 
0x6c, 0x6f, 0x77, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, - 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x22, 0xbc, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, - 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, - 0x44, 0x12, 0x26, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 
0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, - 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, - 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, - 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 
0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, - 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, - 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, - 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 
0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, - 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, - 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, - 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, - 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, - 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, - 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, - 0x69, 0x76, 0x65, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, - 0x6c, 0x61, 0x73, 
0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, - 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, - 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, - 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, - 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, - 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, - 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, - 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, - 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, - 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 
0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, - 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, - 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 
0x6d, - 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, - 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, - 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, - 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, - 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, - 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, - 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, - 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, - 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, - 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 
0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, - 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, - 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, - 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, - 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, - 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, - 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, - 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 
0x6f, 0x63, 0x6f, 0x6c, 0x52, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, - 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, - 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, - 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, - 0x64, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, - 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, - 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, - 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, - 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x70, 
0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x73, 0x65, 0x72, - 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x75, - 0x73, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, - 0x69, 0x78, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, - 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, - 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, - 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, - 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, - 0x74, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, - 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, - 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, - 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, - 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2b, 0x0a, 0x11, - 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x14, 0x0a, 0x12, 0x53, 0x74, 0x6f, - 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, - 0x3a, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x0e, - 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x10, 0x00, - 0x12, 0x0d, 0x0a, 0x09, 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, 0x65, 0x64, 0x10, 0x01, 0x12, - 0x0a, 0x0a, 0x06, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, - 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, - 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, - 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, - 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, - 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, - 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, - 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, - 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, - 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 
0x6f, 0x63, 0x6f, - 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, - 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, - 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, - 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, - 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 
0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, + 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, + 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x12, 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, + 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, + 0x82, 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, + 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 
0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, + 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a, 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, + 0x6e, 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, + 0x75, 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, + 0x65, 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, + 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, + 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x3a, 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, + 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, + 0x64, 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, + 0x65, 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 
0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, + 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, + 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, + 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, + 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, + 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, + 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, + 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, - 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 
0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x4c, 0x0a, - 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0b, 0x52, - 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0a, 0x53, 0x74, 0x6f, 0x70, - 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 
0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, + 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, + 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, + 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x22, 0xbc, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x49, 0x44, 0x12, 0x26, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, + 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0c, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, + 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 
0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, + 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, + 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, + 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, + 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, + 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 
0x18, 0x0a, 0x07, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, + 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, + 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, + 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, + 0x18, 0x02, 0x20, 
0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, + 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, + 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 
0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, + 0x74, 0x69, 0x76, 0x65, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, + 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, + 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, + 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, + 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, + 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, + 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 
0x0e, + 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, + 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xfb, 0x02, 0x0a, 0x0c, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x1a, 0x0a, 0x06, 0x50, + 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, + 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x52, 0x08, 0x50, 
0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, + 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, + 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, + 0x65, 0x73, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0e, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x65, 0x73, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, + 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, + 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, + 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, + 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, + 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, + 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, + 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, + 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 
0x03, 0x65, 0x6e, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, + 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, + 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, + 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, + 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x64, 
0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, + 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, + 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, + 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, + 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, + 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, + 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 
0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, + 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, + 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x1f, 0x0a, + 0x0b, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, + 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, + 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, + 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, + 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, + 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 
0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, + 0x0a, 0x12, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, + 0x67, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, + 0x41, 0x75, 0x74, 0x6f, 0x41, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, + 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, + 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x2b, 0x0a, 0x11, 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x14, + 0x0a, 0x12, 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x3a, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x12, 0x12, 0x0a, 0x0e, 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, + 0x65, 0x64, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, + 0x2a, 0x6c, 0x0a, 
0x0e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, + 0x74, 0x79, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, + 0x6c, 0x69, 0x74, 0x79, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x20, 0x0a, + 0x1c, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x53, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x65, 0x73, 0x10, 0x01, 0x12, + 0x1d, 0x0a, 0x19, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x79, 0x49, 0x50, 0x76, 0x36, 0x4f, 0x76, 0x65, 0x72, 0x6c, 0x61, 0x79, 0x10, 0x02, 0x2a, 0x4c, + 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, + 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, + 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, + 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, + 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, + 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, + 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, + 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, + 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, + 0x10, 0x01, 0x2a, 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, + 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, + 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, + 0x45, 0x5f, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 
0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, + 0x45, 0x5f, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, + 0x45, 0x5f, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, + 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, + 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, + 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, + 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, + 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, + 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, + 
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, + 0x12, 0x4c, 0x0a, 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4b, + 0x0a, 0x0b, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0a, 0x53, + 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -5145,166 +5266,168 @@ func file_management_proto_rawDescGZIP() []byte { 
return file_management_proto_rawDescData } -var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 7) +var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 8) var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 55) var file_management_proto_goTypes = []interface{}{ (JobStatus)(0), // 0: management.JobStatus - (RuleProtocol)(0), // 1: management.RuleProtocol - (RuleDirection)(0), // 2: management.RuleDirection - (RuleAction)(0), // 3: management.RuleAction - (ExposeProtocol)(0), // 4: management.ExposeProtocol - (HostConfig_Protocol)(0), // 5: management.HostConfig.Protocol - (DeviceAuthorizationFlowProvider)(0), // 6: management.DeviceAuthorizationFlow.provider - (*EncryptedMessage)(nil), // 7: management.EncryptedMessage - (*JobRequest)(nil), // 8: management.JobRequest - (*JobResponse)(nil), // 9: management.JobResponse - (*BundleParameters)(nil), // 10: management.BundleParameters - (*BundleResult)(nil), // 11: management.BundleResult - (*SyncRequest)(nil), // 12: management.SyncRequest - (*SyncResponse)(nil), // 13: management.SyncResponse - (*SyncMetaRequest)(nil), // 14: management.SyncMetaRequest - (*LoginRequest)(nil), // 15: management.LoginRequest - (*PeerKeys)(nil), // 16: management.PeerKeys - (*Environment)(nil), // 17: management.Environment - (*File)(nil), // 18: management.File - (*Flags)(nil), // 19: management.Flags - (*PeerSystemMeta)(nil), // 20: management.PeerSystemMeta - (*LoginResponse)(nil), // 21: management.LoginResponse - (*ServerKeyResponse)(nil), // 22: management.ServerKeyResponse - (*Empty)(nil), // 23: management.Empty - (*NetbirdConfig)(nil), // 24: management.NetbirdConfig - (*HostConfig)(nil), // 25: management.HostConfig - (*RelayConfig)(nil), // 26: management.RelayConfig - (*FlowConfig)(nil), // 27: management.FlowConfig - (*JWTConfig)(nil), // 28: management.JWTConfig - (*ProtectedHostConfig)(nil), // 29: management.ProtectedHostConfig - (*PeerConfig)(nil), // 30: management.PeerConfig - 
(*AutoUpdateSettings)(nil), // 31: management.AutoUpdateSettings - (*NetworkMap)(nil), // 32: management.NetworkMap - (*SSHAuth)(nil), // 33: management.SSHAuth - (*MachineUserIndexes)(nil), // 34: management.MachineUserIndexes - (*RemotePeerConfig)(nil), // 35: management.RemotePeerConfig - (*SSHConfig)(nil), // 36: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 37: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 38: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 39: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 40: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 41: management.ProviderConfig - (*Route)(nil), // 42: management.Route - (*DNSConfig)(nil), // 43: management.DNSConfig - (*CustomZone)(nil), // 44: management.CustomZone - (*SimpleRecord)(nil), // 45: management.SimpleRecord - (*NameServerGroup)(nil), // 46: management.NameServerGroup - (*NameServer)(nil), // 47: management.NameServer - (*FirewallRule)(nil), // 48: management.FirewallRule - (*NetworkAddress)(nil), // 49: management.NetworkAddress - (*Checks)(nil), // 50: management.Checks - (*PortInfo)(nil), // 51: management.PortInfo - (*RouteFirewallRule)(nil), // 52: management.RouteFirewallRule - (*ForwardingRule)(nil), // 53: management.ForwardingRule - (*ExposeServiceRequest)(nil), // 54: management.ExposeServiceRequest - (*ExposeServiceResponse)(nil), // 55: management.ExposeServiceResponse - (*RenewExposeRequest)(nil), // 56: management.RenewExposeRequest - (*RenewExposeResponse)(nil), // 57: management.RenewExposeResponse - (*StopExposeRequest)(nil), // 58: management.StopExposeRequest - (*StopExposeResponse)(nil), // 59: management.StopExposeResponse - nil, // 60: management.SSHAuth.MachineUsersEntry - (*PortInfo_Range)(nil), // 61: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 62: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 63: 
google.protobuf.Duration + (PeerCapability)(0), // 1: management.PeerCapability + (RuleProtocol)(0), // 2: management.RuleProtocol + (RuleDirection)(0), // 3: management.RuleDirection + (RuleAction)(0), // 4: management.RuleAction + (ExposeProtocol)(0), // 5: management.ExposeProtocol + (HostConfig_Protocol)(0), // 6: management.HostConfig.Protocol + (DeviceAuthorizationFlowProvider)(0), // 7: management.DeviceAuthorizationFlow.provider + (*EncryptedMessage)(nil), // 8: management.EncryptedMessage + (*JobRequest)(nil), // 9: management.JobRequest + (*JobResponse)(nil), // 10: management.JobResponse + (*BundleParameters)(nil), // 11: management.BundleParameters + (*BundleResult)(nil), // 12: management.BundleResult + (*SyncRequest)(nil), // 13: management.SyncRequest + (*SyncResponse)(nil), // 14: management.SyncResponse + (*SyncMetaRequest)(nil), // 15: management.SyncMetaRequest + (*LoginRequest)(nil), // 16: management.LoginRequest + (*PeerKeys)(nil), // 17: management.PeerKeys + (*Environment)(nil), // 18: management.Environment + (*File)(nil), // 19: management.File + (*Flags)(nil), // 20: management.Flags + (*PeerSystemMeta)(nil), // 21: management.PeerSystemMeta + (*LoginResponse)(nil), // 22: management.LoginResponse + (*ServerKeyResponse)(nil), // 23: management.ServerKeyResponse + (*Empty)(nil), // 24: management.Empty + (*NetbirdConfig)(nil), // 25: management.NetbirdConfig + (*HostConfig)(nil), // 26: management.HostConfig + (*RelayConfig)(nil), // 27: management.RelayConfig + (*FlowConfig)(nil), // 28: management.FlowConfig + (*JWTConfig)(nil), // 29: management.JWTConfig + (*ProtectedHostConfig)(nil), // 30: management.ProtectedHostConfig + (*PeerConfig)(nil), // 31: management.PeerConfig + (*AutoUpdateSettings)(nil), // 32: management.AutoUpdateSettings + (*NetworkMap)(nil), // 33: management.NetworkMap + (*SSHAuth)(nil), // 34: management.SSHAuth + (*MachineUserIndexes)(nil), // 35: management.MachineUserIndexes + (*RemotePeerConfig)(nil), // 36: 
management.RemotePeerConfig + (*SSHConfig)(nil), // 37: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 38: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 39: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 40: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 41: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 42: management.ProviderConfig + (*Route)(nil), // 43: management.Route + (*DNSConfig)(nil), // 44: management.DNSConfig + (*CustomZone)(nil), // 45: management.CustomZone + (*SimpleRecord)(nil), // 46: management.SimpleRecord + (*NameServerGroup)(nil), // 47: management.NameServerGroup + (*NameServer)(nil), // 48: management.NameServer + (*FirewallRule)(nil), // 49: management.FirewallRule + (*NetworkAddress)(nil), // 50: management.NetworkAddress + (*Checks)(nil), // 51: management.Checks + (*PortInfo)(nil), // 52: management.PortInfo + (*RouteFirewallRule)(nil), // 53: management.RouteFirewallRule + (*ForwardingRule)(nil), // 54: management.ForwardingRule + (*ExposeServiceRequest)(nil), // 55: management.ExposeServiceRequest + (*ExposeServiceResponse)(nil), // 56: management.ExposeServiceResponse + (*RenewExposeRequest)(nil), // 57: management.RenewExposeRequest + (*RenewExposeResponse)(nil), // 58: management.RenewExposeResponse + (*StopExposeRequest)(nil), // 59: management.StopExposeRequest + (*StopExposeResponse)(nil), // 60: management.StopExposeResponse + nil, // 61: management.SSHAuth.MachineUsersEntry + (*PortInfo_Range)(nil), // 62: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 63: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 64: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ - 10, // 0: management.JobRequest.bundle:type_name -> management.BundleParameters + 11, // 0: management.JobRequest.bundle:type_name -> management.BundleParameters 0, // 1: 
management.JobResponse.status:type_name -> management.JobStatus - 11, // 2: management.JobResponse.bundle:type_name -> management.BundleResult - 20, // 3: management.SyncRequest.meta:type_name -> management.PeerSystemMeta - 24, // 4: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig - 30, // 5: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 35, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 32, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 50, // 8: management.SyncResponse.Checks:type_name -> management.Checks - 20, // 9: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta - 20, // 10: management.LoginRequest.meta:type_name -> management.PeerSystemMeta - 16, // 11: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 49, // 12: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress - 17, // 13: management.PeerSystemMeta.environment:type_name -> management.Environment - 18, // 14: management.PeerSystemMeta.files:type_name -> management.File - 19, // 15: management.PeerSystemMeta.flags:type_name -> management.Flags - 24, // 16: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig - 30, // 17: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 50, // 18: management.LoginResponse.Checks:type_name -> management.Checks - 62, // 19: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp - 25, // 20: management.NetbirdConfig.stuns:type_name -> management.HostConfig - 29, // 21: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig - 25, // 22: management.NetbirdConfig.signal:type_name -> management.HostConfig - 26, // 23: management.NetbirdConfig.relay:type_name -> management.RelayConfig - 27, // 24: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 5, // 25: 
management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 63, // 26: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 25, // 27: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 36, // 28: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 31, // 29: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings - 30, // 30: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 35, // 31: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 42, // 32: management.NetworkMap.Routes:type_name -> management.Route - 43, // 33: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 35, // 34: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 48, // 35: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 52, // 36: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 53, // 37: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 33, // 38: management.NetworkMap.sshAuth:type_name -> management.SSHAuth - 60, // 39: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry - 36, // 40: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 28, // 41: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 6, // 42: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 41, // 43: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 41, // 44: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 46, // 45: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 44, // 46: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 45, // 47: management.CustomZone.Records:type_name -> management.SimpleRecord - 
47, // 48: management.NameServerGroup.NameServers:type_name -> management.NameServer - 2, // 49: management.FirewallRule.Direction:type_name -> management.RuleDirection - 3, // 50: management.FirewallRule.Action:type_name -> management.RuleAction - 1, // 51: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 51, // 52: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 61, // 53: management.PortInfo.range:type_name -> management.PortInfo.Range - 3, // 54: management.RouteFirewallRule.action:type_name -> management.RuleAction - 1, // 55: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 51, // 56: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 1, // 57: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 51, // 58: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 51, // 59: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 4, // 60: management.ExposeServiceRequest.protocol:type_name -> management.ExposeProtocol - 34, // 61: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes - 7, // 62: management.ManagementService.Login:input_type -> management.EncryptedMessage - 7, // 63: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 23, // 64: management.ManagementService.GetServerKey:input_type -> management.Empty - 23, // 65: management.ManagementService.isHealthy:input_type -> management.Empty - 7, // 66: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 7, // 67: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 7, // 68: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 7, // 69: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 7, // 70: management.ManagementService.Job:input_type -> 
management.EncryptedMessage - 7, // 71: management.ManagementService.CreateExpose:input_type -> management.EncryptedMessage - 7, // 72: management.ManagementService.RenewExpose:input_type -> management.EncryptedMessage - 7, // 73: management.ManagementService.StopExpose:input_type -> management.EncryptedMessage - 7, // 74: management.ManagementService.Login:output_type -> management.EncryptedMessage - 7, // 75: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 22, // 76: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 23, // 77: management.ManagementService.isHealthy:output_type -> management.Empty - 7, // 78: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 7, // 79: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 23, // 80: management.ManagementService.SyncMeta:output_type -> management.Empty - 23, // 81: management.ManagementService.Logout:output_type -> management.Empty - 7, // 82: management.ManagementService.Job:output_type -> management.EncryptedMessage - 7, // 83: management.ManagementService.CreateExpose:output_type -> management.EncryptedMessage - 7, // 84: management.ManagementService.RenewExpose:output_type -> management.EncryptedMessage - 7, // 85: management.ManagementService.StopExpose:output_type -> management.EncryptedMessage - 74, // [74:86] is the sub-list for method output_type - 62, // [62:74] is the sub-list for method input_type - 62, // [62:62] is the sub-list for extension type_name - 62, // [62:62] is the sub-list for extension extendee - 0, // [0:62] is the sub-list for field type_name + 12, // 2: management.JobResponse.bundle:type_name -> management.BundleResult + 21, // 3: management.SyncRequest.meta:type_name -> management.PeerSystemMeta + 25, // 4: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig + 31, // 5: 
management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 36, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 33, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 51, // 8: management.SyncResponse.Checks:type_name -> management.Checks + 21, // 9: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta + 21, // 10: management.LoginRequest.meta:type_name -> management.PeerSystemMeta + 17, // 11: management.LoginRequest.peerKeys:type_name -> management.PeerKeys + 50, // 12: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 18, // 13: management.PeerSystemMeta.environment:type_name -> management.Environment + 19, // 14: management.PeerSystemMeta.files:type_name -> management.File + 20, // 15: management.PeerSystemMeta.flags:type_name -> management.Flags + 1, // 16: management.PeerSystemMeta.capabilities:type_name -> management.PeerCapability + 25, // 17: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig + 31, // 18: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 51, // 19: management.LoginResponse.Checks:type_name -> management.Checks + 63, // 20: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 26, // 21: management.NetbirdConfig.stuns:type_name -> management.HostConfig + 30, // 22: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig + 26, // 23: management.NetbirdConfig.signal:type_name -> management.HostConfig + 27, // 24: management.NetbirdConfig.relay:type_name -> management.RelayConfig + 28, // 25: management.NetbirdConfig.flow:type_name -> management.FlowConfig + 6, // 26: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 64, // 27: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 26, // 28: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 37, // 29: 
management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 32, // 30: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings + 31, // 31: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 36, // 32: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 43, // 33: management.NetworkMap.Routes:type_name -> management.Route + 44, // 34: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 36, // 35: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 49, // 36: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 53, // 37: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 54, // 38: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 34, // 39: management.NetworkMap.sshAuth:type_name -> management.SSHAuth + 61, // 40: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry + 37, // 41: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 29, // 42: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 7, // 43: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 42, // 44: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 42, // 45: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 47, // 46: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 45, // 47: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 46, // 48: management.CustomZone.Records:type_name -> management.SimpleRecord + 48, // 49: management.NameServerGroup.NameServers:type_name -> management.NameServer + 3, // 50: management.FirewallRule.Direction:type_name -> management.RuleDirection + 4, // 51: management.FirewallRule.Action:type_name -> management.RuleAction + 2, // 52: 
management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 52, // 53: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 62, // 54: management.PortInfo.range:type_name -> management.PortInfo.Range + 4, // 55: management.RouteFirewallRule.action:type_name -> management.RuleAction + 2, // 56: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 52, // 57: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 2, // 58: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 52, // 59: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 52, // 60: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 61: management.ExposeServiceRequest.protocol:type_name -> management.ExposeProtocol + 35, // 62: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes + 8, // 63: management.ManagementService.Login:input_type -> management.EncryptedMessage + 8, // 64: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 24, // 65: management.ManagementService.GetServerKey:input_type -> management.Empty + 24, // 66: management.ManagementService.isHealthy:input_type -> management.Empty + 8, // 67: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 8, // 68: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 8, // 69: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 8, // 70: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 8, // 71: management.ManagementService.Job:input_type -> management.EncryptedMessage + 8, // 72: management.ManagementService.CreateExpose:input_type -> management.EncryptedMessage + 8, // 73: management.ManagementService.RenewExpose:input_type -> management.EncryptedMessage + 8, // 74: 
management.ManagementService.StopExpose:input_type -> management.EncryptedMessage + 8, // 75: management.ManagementService.Login:output_type -> management.EncryptedMessage + 8, // 76: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 23, // 77: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 24, // 78: management.ManagementService.isHealthy:output_type -> management.Empty + 8, // 79: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 8, // 80: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 24, // 81: management.ManagementService.SyncMeta:output_type -> management.Empty + 24, // 82: management.ManagementService.Logout:output_type -> management.Empty + 8, // 83: management.ManagementService.Job:output_type -> management.EncryptedMessage + 8, // 84: management.ManagementService.CreateExpose:output_type -> management.EncryptedMessage + 8, // 85: management.ManagementService.RenewExpose:output_type -> management.EncryptedMessage + 8, // 86: management.ManagementService.StopExpose:output_type -> management.EncryptedMessage + 75, // [75:87] is the sub-list for method output_type + 63, // [63:75] is the sub-list for method input_type + 63, // [63:63] is the sub-list for extension type_name + 63, // [63:63] is the sub-list for extension extendee + 0, // [0:63] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -5977,7 +6100,7 @@ func file_management_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, - NumEnums: 7, + NumEnums: 8, NumMessages: 55, NumExtensions: 0, NumServices: 1, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 70a530679..461a614fe 100644 --- a/shared/management/proto/management.proto +++ 
b/shared/management/proto/management.proto @@ -200,6 +200,18 @@ message Flags { bool enableSSHLocalPortForwarding = 13; bool enableSSHRemotePortForwarding = 14; bool disableSSHAuth = 15; + + bool disableIPv6 = 16; +} + +// PeerCapability represents a feature the client binary supports. +// Reported in PeerSystemMeta.capabilities on every login/sync. +enum PeerCapability { + PeerCapabilityUnknown = 0; + // Client reads SourcePrefixes instead of the deprecated PeerIP string. + PeerCapabilitySourcePrefixes = 1; + // Client handles IPv6 overlay addresses and firewall rules. + PeerCapabilityIPv6Overlay = 2; } // PeerSystemMeta is machine meta data like OS and version. @@ -221,6 +233,8 @@ message PeerSystemMeta { Environment environment = 15; repeated File files = 16; Flags flags = 17; + + repeated PeerCapability capabilities = 18; } message LoginResponse { @@ -335,6 +349,9 @@ message PeerConfig { // Auto-update config AutoUpdateSettings autoUpdate = 8; + + // IPv6 overlay address as compact bytes: 16 bytes IP + 1 byte prefix length. + bytes address_v6 = 9; } message AutoUpdateSettings { @@ -567,7 +584,8 @@ enum RuleAction { // FirewallRule represents a firewall rule message FirewallRule { - string PeerIP = 1; + // Use sourcePrefixes instead. + string PeerIP = 1 [deprecated = true]; RuleDirection Direction = 2; RuleAction Action = 3; RuleProtocol Protocol = 4; @@ -576,6 +594,13 @@ message FirewallRule { // PolicyID is the ID of the policy that this rule belongs to bytes PolicyID = 7; + + // CustomProtocol is a custom protocol ID when Protocol is CUSTOM. + uint32 customProtocol = 8; + + // Compact source IP prefixes for this rule, supersedes PeerIP. + // Each entry is 5 bytes (v4) or 17 bytes (v6): [IP bytes][1 byte prefix_len]. 
+ repeated bytes sourcePrefixes = 9; } message NetworkAddress { diff --git a/shared/netiputil/compact.go b/shared/netiputil/compact.go new file mode 100644 index 000000000..0cd2b8a20 --- /dev/null +++ b/shared/netiputil/compact.go @@ -0,0 +1,78 @@ +// Package netiputil provides compact binary encoding for IP prefixes used in +// the management proto wire format. +// +// Format: [IP bytes][1 byte prefix_len] +// - IPv4: 5 bytes total (4 IP + 1 prefix_len, 0-32) +// - IPv6: 17 bytes total (16 IP + 1 prefix_len, 0-128) +// +// Address family is determined by length: 5 = v4, 17 = v6. +package netiputil + +import ( + "fmt" + "net/netip" +) + +// EncodePrefix encodes a netip.Prefix into compact bytes. +// The address is always unmapped before encoding. +func EncodePrefix(p netip.Prefix) ([]byte, error) { + addr := p.Addr().Unmap() + bits := p.Bits() + if addr.Is4() && bits > 32 { + return nil, fmt.Errorf("invalid prefix length %d for IPv4 address %s (max 32)", bits, addr) + } + return append(addr.AsSlice(), byte(bits)), nil +} + +// DecodePrefix decodes compact bytes into a netip.Prefix. 
+func DecodePrefix(b []byte) (netip.Prefix, error) { + switch len(b) { + case 5: + var ip4 [4]byte + copy(ip4[:], b) + bits := int(b[len(b)-1]) + if bits > 32 { + return netip.Prefix{}, fmt.Errorf("invalid IPv4 prefix length %d (max 32)", bits) + } + return netip.PrefixFrom(netip.AddrFrom4(ip4), bits), nil + case 17: + var ip6 [16]byte + copy(ip6[:], b) + addr := netip.AddrFrom16(ip6).Unmap() + bits := int(b[len(b)-1]) + if addr.Is4() { + if bits > 32 { + return netip.Prefix{}, fmt.Errorf("invalid prefix length %d for v4-mapped address (max 32)", bits) + } + } else if bits > 128 { + return netip.Prefix{}, fmt.Errorf("invalid IPv6 prefix length %d (max 128)", bits) + } + return netip.PrefixFrom(addr, bits), nil + default: + return netip.Prefix{}, fmt.Errorf("invalid compact prefix length %d (expected 5 or 17)", len(b)) + } +} + +// EncodeAddr encodes a netip.Addr into compact prefix bytes with a host prefix +// length (/32 for v4, /128 for v6). The address is always unmapped before encoding. +func EncodeAddr(a netip.Addr) []byte { + a = a.Unmap() + bits := 128 + if a.Is4() { + bits = 32 + } + // Host prefix lengths are always valid for the address family, so error is impossible. + b, _ := EncodePrefix(netip.PrefixFrom(a, bits)) + return b +} + +// DecodeAddr decodes compact prefix bytes and returns only the address, +// discarding the prefix length. Useful when the prefix length is implied +// (e.g. peer overlay IPs are always /32 or /128). 
+func DecodeAddr(b []byte) (netip.Addr, error) { + p, err := DecodePrefix(b) + if err != nil { + return netip.Addr{}, err + } + return p.Addr(), nil +} diff --git a/shared/netiputil/compact_test.go b/shared/netiputil/compact_test.go new file mode 100644 index 000000000..1e7c7ed82 --- /dev/null +++ b/shared/netiputil/compact_test.go @@ -0,0 +1,175 @@ +package netiputil + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeDecodePrefix(t *testing.T) { + tests := []struct { + name string + prefix string + size int + }{ + { + name: "v4 host", + prefix: "100.64.0.1/32", + size: 5, + }, + { + name: "v4 network", + prefix: "10.0.0.0/8", + size: 5, + }, + { + name: "v4 default", + prefix: "0.0.0.0/0", + size: 5, + }, + { + name: "v6 host", + prefix: "fd00::1/128", + size: 17, + }, + { + name: "v6 network", + prefix: "fd00:1234:5678::/48", + size: 17, + }, + { + name: "v6 default", + prefix: "::/0", + size: 17, + }, + { + name: "v4 /16 overlay", + prefix: "100.64.0.1/16", + size: 5, + }, + { + name: "v6 /64 overlay", + prefix: "fd00::abcd:1/64", + size: 17, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := netip.MustParsePrefix(tt.prefix) + b, err := EncodePrefix(p) + require.NoError(t, err) + assert.Equal(t, tt.size, len(b), "encoded size") + + decoded, err := DecodePrefix(b) + require.NoError(t, err) + assert.Equal(t, p, decoded) + }) + } +} + +func TestEncodePrefixUnmaps(t *testing.T) { + // v4-mapped v6 address should encode as v4 + mapped := netip.MustParsePrefix("::ffff:10.1.2.3/32") + b, err := EncodePrefix(mapped) + require.NoError(t, err) + assert.Equal(t, 5, len(b), "v4-mapped should encode as 5 bytes") + + decoded, err := DecodePrefix(b) + require.NoError(t, err) + assert.Equal(t, netip.MustParsePrefix("10.1.2.3/32"), decoded) +} + +func TestEncodePrefixUnmapsRejectsInvalidBits(t *testing.T) { + // v4-mapped v6 with bits > 32 should return an 
error + mapped128 := netip.MustParsePrefix("::ffff:10.1.2.3/128") + _, err := EncodePrefix(mapped128) + require.Error(t, err) + + // v4-mapped v6 with bits=96 should also return an error + mapped96 := netip.MustParsePrefix("::ffff:10.0.0.0/96") + _, err = EncodePrefix(mapped96) + require.Error(t, err) + + // v4-mapped v6 with bits=32 should succeed + mapped32 := netip.MustParsePrefix("::ffff:10.1.2.3/32") + b, err := EncodePrefix(mapped32) + require.NoError(t, err) + assert.Equal(t, 5, len(b), "v4-mapped should encode as 5 bytes") + + decoded, err := DecodePrefix(b) + require.NoError(t, err) + assert.Equal(t, netip.MustParsePrefix("10.1.2.3/32"), decoded) +} + +func TestDecodeAddr(t *testing.T) { + v4 := netip.MustParseAddr("100.64.0.5") + b := EncodeAddr(v4) + assert.Equal(t, 5, len(b)) + + got, err := DecodeAddr(b) + require.NoError(t, err) + assert.Equal(t, v4, got) + + v6 := netip.MustParseAddr("fd00::1") + b = EncodeAddr(v6) + assert.Equal(t, 17, len(b)) + + got, err = DecodeAddr(b) + require.NoError(t, err) + assert.Equal(t, v6, got) +} + +func TestDecodePrefixInvalidLength(t *testing.T) { + _, err := DecodePrefix([]byte{1, 2, 3}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid compact prefix length 3") + + _, err = DecodePrefix(nil) + assert.Error(t, err) + + _, err = DecodePrefix([]byte{}) + assert.Error(t, err) +} + +func TestDecodePrefixInvalidBits(t *testing.T) { + // v4 with bits > 32 + b := []byte{10, 0, 0, 1, 33} + _, err := DecodePrefix(b) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid IPv4 prefix length 33") + + // v6 with bits > 128 + b = make([]byte, 17) + b[0] = 0xfd + b[16] = 129 + _, err = DecodePrefix(b) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid IPv6 prefix length 129") +} + +func TestDecodePrefixUnmapsV6Input(t *testing.T) { + addr := netip.MustParseAddr("::ffff:192.168.1.1") + + // v4-mapped v6 with bits > 32 should return an error + raw := addr.As16() + bInvalid := make([]byte, 
17) + copy(bInvalid, raw[:]) + bInvalid[16] = 128 + + _, err := DecodePrefix(bInvalid) + require.Error(t, err, "v4-mapped address with /128 prefix should be rejected") + assert.Contains(t, err.Error(), "invalid prefix length") + + // v4-mapped v6 with valid /32 should decode and unmap correctly + bValid := make([]byte, 17) + copy(bValid, raw[:]) + bValid[16] = 32 + + decoded, err := DecodePrefix(bValid) + require.NoError(t, err) + assert.True(t, decoded.Addr().Is4(), "should be unmapped to v4") + assert.Equal(t, netip.MustParsePrefix("192.168.1.1/32"), decoded) +} diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index b10b05617..1800bddb2 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -2,8 +2,12 @@ package client import ( "context" + "errors" "fmt" "net" + "net/netip" + "net/url" + "strings" "sync" "time" @@ -146,6 +150,7 @@ func (cc *connContainer) close() { type Client struct { log *log.Entry connectionURL string + serverIP netip.Addr authTokenStore *auth.TokenStore hashedID messages.PeerID @@ -170,13 +175,22 @@ type Client struct { } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect +// is called. func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { + return NewClientWithServerIP(serverURL, netip.Addr{}, authTokenStore, peerID, mtu) +} + +// NewClientWithServerIP creates a new client for the relay server with a known server IP. serverIP, when valid, is +// dialed directly first; the FQDN is only attempted if the IP-based dial fails. TLS verification still uses the +// FQDN from serverURL via SNI. 
+func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { hashedID := messages.HashID(peerID) relayLog := log.WithFields(log.Fields{"relay": serverURL}) c := &Client{ log: relayLog, connectionURL: serverURL, + serverIP: serverIP, authTokenStore: authTokenStore, hashedID: hashedID, mtu: mtu, @@ -304,6 +318,23 @@ func (c *Client) ServerInstanceURL() (string, error) { return c.instanceURL.String(), nil } +// ConnectedIP returns the IP address of the live relay-server connection, +// extracted from the underlying socket's RemoteAddr. Zero value if not +// connected or if the address is not an IP literal. +func (c *Client) ConnectedIP() netip.Addr { + c.mu.Lock() + conn := c.relayConn + c.mu.Unlock() + if conn == nil { + return netip.Addr{} + } + addr := conn.RemoteAddr() + if addr == nil { + return netip.Addr{} + } + return extractIPLiteral(addr.String()) +} + // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. func (c *Client) SetOnDisconnectListener(fn func(string)) { c.listenerMutex.Lock() @@ -332,10 +363,23 @@ func (c *Client) Close() error { func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { dialers := c.getDialers() - rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) - conn, err := rd.Dial(ctx) - if err != nil { - return nil, err + var conn net.Conn + if c.serverIP.IsValid() { + var err error + conn, err = c.dialRaceDirect(ctx, dialers) + if err != nil { + c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err) + conn = nil + } + } + + if conn == nil { + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) 
+ var err error + conn, err = rd.Dial(ctx) + if err != nil { + return nil, fmt.Errorf("dial via FQDN: %w", err) + } } c.relayConn = conn @@ -351,6 +395,52 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { return instanceURL, nil } +// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI. +func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) { + directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP) + if err != nil { + return nil, fmt.Errorf("substitute host: %w", err) + } + + c.log.Debugf("dialing via server IP %s (SNI=%s)", c.serverIP, serverName) + + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...). + WithServerName(serverName) + return rd.Dial(ctx) +} + +// substituteHost replaces the host portion of a rel/rels URL with ip, +// preserving the scheme and port. Returns the rewritten URL and the +// original host to use as the TLS ServerName, or empty if the original +// host is itself an IP literal (SNI requires a DNS name). 
+func substituteHost(serverURL string, ip netip.Addr) (string, string, error) { + u, err := url.Parse(serverURL) + if err != nil { + return "", "", fmt.Errorf("parse %q: %w", serverURL, err) + } + if u.Scheme == "" || u.Host == "" { + return "", "", fmt.Errorf("invalid relay URL %q", serverURL) + } + if !ip.IsValid() { + return "", "", errors.New("invalid server IP") + } + origHost := u.Hostname() + if _, err := netip.ParseAddr(origHost); err == nil { + origHost = "" + } + ip = ip.Unmap() + newHost := ip.String() + if ip.Is6() { + newHost = "[" + newHost + "]" + } + if port := u.Port(); port != "" { + u.Host = newHost + ":" + port + } else { + u.Host = newHost + } + return u.String(), origHost, nil +} + func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { @@ -716,3 +806,21 @@ func (c *Client) handlePeersWentOfflineMsg(buf []byte) { } c.stateSubscription.OnPeersWentOffline(peersID) } + +// extractIPLiteral returns the IP from address forms produced by the relay +// dialers (URL or host:port). Zero value if the host is not an IP. 
+func extractIPLiteral(s string) netip.Addr { + if u, err := url.Parse(s); err == nil && u.Host != "" { + s = u.Host + } + host, _, err := net.SplitHostPort(s) + if err != nil { + host = s + } + host = strings.Trim(host, "[]") + ip, err := netip.ParseAddr(host) + if err != nil { + return netip.Addr{} + } + return ip.Unmap() +} diff --git a/shared/relay/client/client_serverip_test.go b/shared/relay/client/client_serverip_test.go new file mode 100644 index 000000000..7e699e37d --- /dev/null +++ b/shared/relay/client/client_serverip_test.go @@ -0,0 +1,280 @@ +package client + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" + + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/shared/relay/auth/allow" +) + +// TestClient_ServerIPRecoversFromUnresolvableFQDN verifies that when the +// primary FQDN-based dial fails (unresolvable .invalid host), Connect +// recovers via the server IP and SNI still uses the FQDN. 
+func TestClient_ServerIPRecoversFromUnresolvableFQDN(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + listenAddr, port := freeAddr(t) + srvCfg := server.Config{ + Meter: otel.Meter(""), + ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port), + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(srvCfg) + if err != nil { + t.Fatalf("create server: %s", err) + } + + errChan := make(chan error, 1) + go func() { + if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil { + errChan <- err + } + }() + t.Cleanup(func() { + if err := srv.Shutdown(context.Background()); err != nil { + t.Errorf("shutdown server: %s", err) + } + }) + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("server failed to start: %s", err) + } + + t.Run("no server IP, primary fails", func(t *testing.T) { + c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-noip", iface.DefaultMTU) + err := c.Connect(ctx) + if err == nil { + _ = c.Close() + t.Fatalf("expected connect to fail without server IP, got nil") + } + }) + + t.Run("server IP recovers", func(t *testing.T) { + c := NewClientWithServerIP(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-with-ip", iface.DefaultMTU) + if err := c.Connect(ctx); err != nil { + t.Fatalf("connect with server IP: %s", err) + } + t.Cleanup(func() { _ = c.Close() }) + + if !c.Ready() { + t.Fatalf("client not ready after connect") + } + if got := c.ConnectedIP(); got.String() != "127.0.0.1" { + t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got) + } + }) +} + +// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the +// resolved IP after a successful FQDN-based dial. 
The underlying socket's +// RemoteAddr must be exposed through the dialer wrappers; if it returns +// the dial-time URL instead, ConnectedIP returns empty and the dial +// IP we advertise to peers is empty too. +func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + listenAddr, port := freeAddr(t) + srvCfg := server.Config{ + Meter: otel.Meter(""), + ExposedAddress: fmt.Sprintf("rel://localhost:%d", port), + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(srvCfg) + if err != nil { + t.Fatalf("create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil { + errChan <- err + } + }() + t.Cleanup(func() { _ = srv.Shutdown(context.Background()) }) + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("server failed to start: %s", err) + } + + c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-fqdn", iface.DefaultMTU) + if err := c.Connect(ctx); err != nil { + t.Fatalf("connect: %s", err) + } + t.Cleanup(func() { _ = c.Close() }) + + got := c.ConnectedIP().String() + if got != "127.0.0.1" && got != "::1" { + t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got) + } +} + +func TestSubstituteHost(t *testing.T) { + tests := []struct { + name string + serverURL string + ip string + wantURL string + wantServerName string + wantErr bool + }{ + { + name: "rels with port", + serverURL: "rels://relay.netbird.io:443", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5:443", + wantServerName: "relay.netbird.io", + }, + { + name: "rel with port", + serverURL: "rel://relay.example.com:80", + ip: "192.0.2.1", + wantURL: "rel://192.0.2.1:80", + wantServerName: "relay.example.com", + }, + { + name: "ipv6 server IP bracketed", + serverURL: "rels://relay.example.com:443", + ip: "2001:db8::1", + wantURL: 
"rels://[2001:db8::1]:443", + wantServerName: "relay.example.com", + }, + { + name: "no port", + serverURL: "rels://relay.example.com", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5", + wantServerName: "relay.example.com", + }, + { + name: "ipv6 server with port returns empty SNI", + serverURL: "rels://[2001:db8::5]:443", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5:443", + wantServerName: "", + }, + { + name: "ipv4 server with port returns empty SNI", + serverURL: "rels://10.0.0.5:443", + ip: "10.0.0.6", + wantURL: "rels://10.0.0.6:443", + wantServerName: "", + }, + { + name: "ipv6 server IP no port", + serverURL: "rels://relay.example.com", + ip: "2001:db8::1", + wantURL: "rels://[2001:db8::1]", + wantServerName: "relay.example.com", + }, + { + name: "missing scheme", + serverURL: "relay.example.com:443", + ip: "10.0.0.5", + wantErr: true, + }, + { + name: "empty", + serverURL: "", + ip: "10.0.0.5", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ip netip.Addr + if tt.ip != "" { + ip = netip.MustParseAddr(tt.ip) + } + gotURL, gotName, err := substituteHost(tt.serverURL, ip) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if gotURL != tt.wantURL { + t.Errorf("URL = %q, want %q", gotURL, tt.wantURL) + } + if gotName != tt.wantServerName { + t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName) + } + }) + } +} + +func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) { + c := NewClient("rel://example.invalid:80", hmacTokenStore, "x", iface.DefaultMTU) + if got := c.ConnectedIP(); got.IsValid() { + t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got) + } +} + +// staticAddr is a net.Addr that returns a fixed string. Used to verify +// ConnectedIP parses RemoteAddr correctly. 
+type staticAddr struct{ s string } + +func (a staticAddr) Network() string { return "tcp" } +func (a staticAddr) String() string { return a.s } + +type stubConn struct { + net.Conn + remote net.Addr +} + +func (s stubConn) RemoteAddr() net.Addr { return s.remote } + +func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) { + tests := []struct { + name string + s string + want string + }{ + {"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"}, + {"hostport ipv6 bracketed", "[::1]:50301", "::1"}, + {"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"}, + {"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"}, + {"fqdn url returns empty", "rel://relay.example.com:50301", ""}, + {"fqdn hostport returns empty", "relay.example.com:50301", ""}, + {"plain ipv4 no port", "10.0.0.1", "10.0.0.1"}, + {"empty", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}} + got := c.ConnectedIP() + var gotStr string + if got.IsValid() { + gotStr = got.String() + } + if gotStr != tt.want { + t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want) + } + }) + } +} + +// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The +// listener is closed before returning, so the port is briefly free for +// the caller to bind. Avoids hardcoded ports that can collide. 
+func freeAddr(t *testing.T) (string, int) { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("get free port: %s", err) + } + addr := l.Addr().(*net.TCPAddr) + _ = l.Close() + return addr.String(), addr.Port +} diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 2d7b00a80..86f6f178d 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -23,7 +23,7 @@ func (d Dialer) Protocol() string { return Network } -func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) { quicURL, err := prepareURL(address) if err != nil { return nil, err @@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { // Get the base TLS config tlsClientConfig := quictls.ClientQUICTLSConfig() - // Set ServerName to hostname if not an IP address - host, _, splitErr := net.SplitHostPort(quicURL) - if splitErr == nil && net.ParseIP(host) == nil { - // It's a hostname, not an IP - modify directly - tlsClientConfig.ServerName = host + switch { + case serverName != "" && net.ParseIP(serverName) == nil: + tlsClientConfig.ServerName = serverName + default: + host, _, splitErr := net.SplitHostPort(quicURL) + if splitErr == nil && net.ParseIP(host) == nil { + tlsClientConfig.ServerName = host + } } quicConfig := &quic.Config{ @@ -46,7 +49,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { InitialPacketSize: nbRelay.QUICInitialPacketSize, } - udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0}) if err != nil { log.Errorf("failed to listen on UDP: %s", err) return nil, err diff --git a/shared/relay/client/dialer/race_dialer.go b/shared/relay/client/dialer/race_dialer.go index 34359d17e..15208b858 100644 --- 
a/shared/relay/client/dialer/race_dialer.go +++ b/shared/relay/client/dialer/race_dialer.go @@ -14,7 +14,9 @@ const ( ) type DialeFn interface { - Dial(ctx context.Context, address string) (net.Conn, error) + // Dial connects to address. serverName, when non-empty, overrides the TLS + // ServerName used for SNI/cert validation. Empty means derive from address. + Dial(ctx context.Context, address, serverName string) (net.Conn, error) Protocol() string } @@ -27,6 +29,7 @@ type dialResult struct { type RaceDial struct { log *log.Entry serverURL string + serverName string dialerFns []DialeFn connectionTimeout time.Duration } @@ -40,6 +43,16 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri } } +// WithServerName sets a TLS SNI/cert validation override. Used when serverURL +// contains an IP literal but the cert is issued for a different hostname. +// +// Mutates the receiver and is not safe for concurrent reconfiguration; a +// RaceDial is intended to be constructed per dial and discarded. 
+func (r *RaceDial) WithServerName(serverName string) *RaceDial { + r.serverName = serverName + return r +} + func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) { connChan := make(chan dialResult, len(r.dialerFns)) winnerConn := make(chan net.Conn, 1) @@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia defer cancel() r.log.Infof("dialing Relay server via %s", dfn.Protocol()) - conn, err := dfn.Dial(ctx, r.serverURL) + conn, err := dfn.Dial(ctx, r.serverURL, r.serverName) connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err} } diff --git a/shared/relay/client/dialer/race_dialer_test.go b/shared/relay/client/dialer/race_dialer_test.go index aa18df578..a53edc00e 100644 --- a/shared/relay/client/dialer/race_dialer_test.go +++ b/shared/relay/client/dialer/race_dialer_test.go @@ -28,7 +28,7 @@ type MockDialer struct { protocolStr string } -func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (m *MockDialer) Dial(ctx context.Context, address, _ string) (net.Conn, error) { return m.dialFunc(ctx, address) } diff --git a/shared/relay/client/dialer/ws/conn.go b/shared/relay/client/dialer/ws/conn.go index d5b719f51..9497fab89 100644 --- a/shared/relay/client/dialer/ws/conn.go +++ b/shared/relay/client/dialer/ws/conn.go @@ -12,14 +12,24 @@ import ( type Conn struct { ctx context.Context *websocket.Conn - remoteAddr WebsocketAddr + remoteAddr net.Addr } -func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn { +// NewConn builds a relay ws.Conn. underlying is the raw TCP/TLS conn captured +// from the http transport's DialContext; when set, RemoteAddr returns its +// peer address (an IP literal). When nil (e.g. wasm), RemoteAddr falls back +// to the dial-time URL. 
+func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn) net.Conn { + var addr net.Addr = WebsocketAddr{serverAddress} + if underlying != nil { + if ra := underlying.RemoteAddr(); ra != nil { + addr = ra + } + } return &Conn{ ctx: context.Background(), Conn: wsConn, - remoteAddr: WebsocketAddr{serverAddress}, + remoteAddr: addr, } } diff --git a/shared/relay/client/dialer/ws/dialopts_generic.go b/shared/relay/client/dialer/ws/dialopts_generic.go index 9dfe698d0..8008d89d3 100644 --- a/shared/relay/client/dialer/ws/dialopts_generic.go +++ b/shared/relay/client/dialer/ws/dialopts_generic.go @@ -2,10 +2,14 @@ package ws -import "github.com/coder/websocket" +import ( + "net" -func createDialOptions() *websocket.DialOptions { + "github.com/coder/websocket" +) + +func createDialOptions(serverName string, underlyingOut *net.Conn) *websocket.DialOptions { return &websocket.DialOptions{ - HTTPClient: httpClientNbDialer(), + HTTPClient: httpClientNbDialer(serverName, underlyingOut), } } diff --git a/shared/relay/client/dialer/ws/dialopts_js.go b/shared/relay/client/dialer/ws/dialopts_js.go index 7eac27531..5b11fe765 100644 --- a/shared/relay/client/dialer/ws/dialopts_js.go +++ b/shared/relay/client/dialer/ws/dialopts_js.go @@ -2,9 +2,13 @@ package ws -import "github.com/coder/websocket" +import ( + "net" -func createDialOptions() *websocket.DialOptions { - // WASM version doesn't support HTTPClient + "github.com/coder/websocket" +) + +func createDialOptions(_ string, _ *net.Conn) *websocket.DialOptions { + // WASM version doesn't support HTTPClient or custom TLS config. 
return &websocket.DialOptions{} } diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 37b189e05..301486514 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -26,13 +26,14 @@ func (d Dialer) Protocol() string { return "WS" } -func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) { wsURL, err := prepareURL(address) if err != nil { return nil, err } - opts := createDialOptions() + var underlying net.Conn + opts := createDialOptions(serverName, &underlying) parsedURL, err := url.Parse(wsURL) if err != nil { @@ -52,7 +53,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { _ = resp.Body.Close() } - conn := NewConn(wsConn, address) + conn := NewConn(wsConn, address, underlying) return conn, nil } @@ -64,7 +65,10 @@ func prepareURL(address string) (string, error) { return strings.Replace(address, "rel", "ws", 1), nil } -func httpClientNbDialer() *http.Client { +// httpClientNbDialer builds the http client used by the websocket library. +// underlyingOut, when non-nil, is populated with the raw conn from the +// transport's DialContext so the caller can read its RemoteAddr. 
+func httpClientNbDialer(serverName string, underlyingOut *net.Conn) *http.Client { customDialer := nbnet.NewDialer() certPool, err := x509.SystemCertPool() @@ -75,10 +79,15 @@ func httpClientNbDialer() *http.Client { customTransport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return customDialer.DialContext(ctx, network, addr) + c, err := customDialer.DialContext(ctx, network, addr) + if err == nil && underlyingOut != nil { + *underlyingOut = c + } + return c, err }, TLSClientConfig: &tls.Config{ - RootCAs: certPool, + RootCAs: certPool, + ServerName: serverName, }, } diff --git a/shared/relay/client/guard.go b/shared/relay/client/guard.go index f4d3a8cce..d7892d0ce 100644 --- a/shared/relay/client/guard.go +++ b/shared/relay/client/guard.go @@ -8,10 +8,7 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - // TODO: make it configurable, the manager should validate all configurable parameters - reconnectingTimeout = 60 * time.Second -) +const defaultMaxBackoffInterval = 60 * time.Second // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { @@ -19,14 +16,23 @@ type Guard struct { OnNewRelayClient chan *Client OnReconnected chan struct{} serverPicker *ServerPicker + + // maxBackoffInterval caps the exponential backoff between reconnect + // attempts. + maxBackoffInterval time.Duration } -// NewGuard creates a new guard for the relay client. -func NewGuard(sp *ServerPicker) *Guard { +// NewGuard creates a new guard for the relay client. A non-positive +// maxBackoffInterval falls back to defaultMaxBackoffInterval. 
+func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard { + if maxBackoffInterval <= 0 { + maxBackoffInterval = defaultMaxBackoffInterval + } g := &Guard{ - OnNewRelayClient: make(chan *Client, 1), - OnReconnected: make(chan struct{}, 1), - serverPicker: sp, + OnNewRelayClient: make(chan *Client, 1), + OnReconnected: make(chan struct{}, 1), + serverPicker: sp, + maxBackoffInterval: maxBackoffInterval, } return g } @@ -49,7 +55,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { } // start a ticker to pick a new server - ticker := exponentTicker(ctx) + ticker := g.exponentTicker(ctx) defer ticker.Stop() for { @@ -125,11 +131,11 @@ func (g *Guard) notifyReconnected() { } } -func exponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) exponentTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 2 * time.Second, Multiplier: 2, - MaxInterval: reconnectingTimeout, + MaxInterval: g.maxBackoffInterval, Clock: backoff.SystemClock, }, ctx) diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index 6220e7f6b..3858b3c83 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "net/netip" "reflect" "sync" "time" @@ -39,6 +40,15 @@ func NewRelayTrack() *RelayTrack { type OnServerCloseListener func() +// ManagerOption configures a Manager at construction time. +type ManagerOption func(*Manager) + +// WithMaxBackoffInterval caps the exponential backoff between reconnect +// attempts to the home relay. A non-positive value keeps the default. +func WithMaxBackoffInterval(d time.Duration) ManagerOption { + return func(m *Manager) { m.maxBackoffInterval = d } +} + // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL // and automatically reconnect to them in case disconnection. 
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a @@ -64,12 +74,16 @@ type Manager struct { onReconnectedListenerFn func() listenerLock sync.Mutex - mtu uint16 + mtu uint16 + maxBackoffInterval time.Duration + + cleanupInterval time.Duration + keepUnusedServerTime time.Duration } // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. -func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager { +func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager { tokenStore := &relayAuth.TokenStore{} m := &Manager{ @@ -85,9 +99,14 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), + cleanupInterval: relayCleanupInterval, + keepUnusedServerTime: keepUnusedServerTime, + } + for _, opt := range opts { + opt(m) } m.serverPicker.ServerURLs.Store(serverURLs) - m.reconnectGuard = NewGuard(m.serverPicker) + m.reconnectGuard = NewGuard(m.serverPicker, m.maxBackoffInterval) return m } @@ -117,7 +136,10 @@ func (m *Manager) Serve() error { // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. -func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { +// +// serverIP, when valid and serverAddress is foreign, is used as a dial target if the FQDN-based dial fails. +// Ignored for the local home-server path. TLS verification still uses the FQDN via SNI. 
+func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) { m.relayClientMu.RLock() defer m.relayClientMu.RUnlock() @@ -138,7 +160,7 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) ( netConn, err = m.relayClient.OpenConn(ctx, peerKey) } else { log.Debugf("open peer connection via foreign server: %s", serverAddress) - netConn, err = m.openConnVia(ctx, serverAddress, peerKey) + netConn, err = m.openConnVia(ctx, serverAddress, peerKey, serverIP) } if err != nil { return nil, err @@ -190,16 +212,22 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ return nil } -// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is -// lost. This address will be sent to the target peer to choose the common relay server for the communication. -func (m *Manager) RelayInstanceAddress() (string, error) { +// RelayInstanceAddress returns the address and resolved IP of the permanent relay server. It could change if the +// network connection is lost. The address is sent to the target peer to choose the common relay server for the +// communication; the IP is sent alongside so remote peers can dial directly without their own DNS lookup. Both +// values are read under the same lock so they cannot diverge across a reconnection. +func (m *Manager) RelayInstanceAddress() (string, netip.Addr, error) { m.relayClientMu.RLock() defer m.relayClientMu.RUnlock() if m.relayClient == nil { - return "", ErrRelayClientNotConnected + return "", netip.Addr{}, ErrRelayClientNotConnected } - return m.relayClient.ServerInstanceURL() + addr, err := m.relayClient.ServerInstanceURL() + if err != nil { + return "", netip.Addr{}, err + } + return addr, m.relayClient.ConnectedIP(), nil } // ServerURLs returns the addresses of the relay servers. 
@@ -223,7 +251,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error { return m.tokenStore.UpdateToken(token) } -func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { +func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) { // check if already has a connection to the desired relay server m.relayClientsMutex.RLock() rt, ok := m.relayClients[serverAddress] @@ -258,7 +286,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string m.relayClients[serverAddress] = rt m.relayClientsMutex.Unlock() - relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu) + relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu) err := relayClient.Connect(m.ctx) if err != nil { rt.err = err @@ -290,19 +318,36 @@ func (m *Manager) onServerConnected() { go m.onReconnectedListenerFn() } -// onServerDisconnected start to reconnection for home server only +// onServerDisconnected handles relay disconnect events. For the home server it +// starts the reconnect guard. For foreign servers it evicts the now-dead client +// from the cache so the next OpenConn builds a fresh one instead of reusing a +// closed client. 
func (m *Manager) onServerDisconnected(serverAddress string) { m.relayClientMu.Lock() - if serverAddress == m.relayClient.connectionURL { + isHome := m.relayClient != nil && serverAddress == m.relayClient.connectionURL + if isHome { go func(client *Client) { m.reconnectGuard.StartReconnectTrys(m.ctx, client) }(m.relayClient) } m.relayClientMu.Unlock() + if !isHome { + m.evictForeignRelay(serverAddress) + } + m.notifyOnDisconnectListeners(serverAddress) } +func (m *Manager) evictForeignRelay(serverAddress string) { + m.relayClientsMutex.Lock() + defer m.relayClientsMutex.Unlock() + if _, ok := m.relayClients[serverAddress]; ok { + delete(m.relayClients, serverAddress) + log.Debugf("evicted disconnected foreign relay client: %s", serverAddress) + } +} + func (m *Manager) listenGuardEvent(ctx context.Context) { for { select { @@ -334,7 +379,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - ticker := time.NewTicker(relayCleanupInterval) + ticker := time.NewTicker(m.cleanupInterval) defer ticker.Stop() for { select { @@ -359,7 +404,7 @@ func (m *Manager) cleanUpUnusedRelays() { continue } - if time.Since(rt.created) <= keepUnusedServerTime { + if time.Since(rt.created) <= m.keepUnusedServerTime { rt.Unlock() continue } diff --git a/shared/relay/client/manager_serverip_test.go b/shared/relay/client/manager_serverip_test.go new file mode 100644 index 000000000..a354beade --- /dev/null +++ b/shared/relay/client/manager_serverip_test.go @@ -0,0 +1,144 @@ +package client + +import ( + "context" + "io" + "net/netip" + "testing" + "time" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" +) + +// TestManager_ForeignRelayServerIP exercises the foreign-relay path +// end-to-end through Manager.OpenConn. Alice and Bob register on different +// relay servers; Alice dials Bob's foreign relay using an unresolvable +// FQDN. 
Without a server IP the dial fails; with Bob's advertised IP it +// recovers and a payload round-trips between the peers. +func TestManager_ForeignRelayServerIP(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + // Alice's home relay + homeCfg := server.ListenerConfig{Address: "127.0.0.1:52401"} + homeSrv, err := server.NewServer(newManagerTestServerConfig(homeCfg.Address)) + if err != nil { + t.Fatalf("create home server: %s", err) + } + homeErr := make(chan error, 1) + go func() { + if err := homeSrv.Listen(homeCfg); err != nil { + homeErr <- err + } + }() + t.Cleanup(func() { _ = homeSrv.Shutdown(context.Background()) }) + if err := waitForServerToStart(homeErr); err != nil { + t.Fatalf("home server: %s", err) + } + + // Bob's foreign relay + foreignCfg := server.ListenerConfig{Address: "127.0.0.1:52402"} + foreignSrv, err := server.NewServer(newManagerTestServerConfig(foreignCfg.Address)) + if err != nil { + t.Fatalf("create foreign server: %s", err) + } + foreignErr := make(chan error, 1) + go func() { + if err := foreignSrv.Listen(foreignCfg); err != nil { + foreignErr <- err + } + }() + t.Cleanup(func() { _ = foreignSrv.Shutdown(context.Background()) }) + if err := waitForServerToStart(foreignErr); err != nil { + t.Fatalf("foreign server: %s", err) + } + + mCtx, mCancel := context.WithCancel(ctx) + t.Cleanup(mCancel) + + mgrAlice := NewManager(mCtx, toURL(homeCfg), "alice", iface.DefaultMTU) + if err := mgrAlice.Serve(); err != nil { + t.Fatalf("alice manager serve: %s", err) + } + + mgrBob := NewManager(mCtx, toURL(foreignCfg), "bob", iface.DefaultMTU) + if err := mgrBob.Serve(); err != nil { + t.Fatalf("bob manager serve: %s", err) + } + + // Bob's real relay URL and the IP that would ride along in signal as relayServerIP. 
+ bobRealAddr, bobAdvertisedIP, err := mgrBob.RelayInstanceAddress() + if err != nil { + t.Fatalf("bob relay address: %s", err) + } + if !bobAdvertisedIP.IsValid() { + t.Fatalf("expected valid RelayInstanceIP for bob, got zero") + } + + // .invalid is reserved (RFC 2606), so DNS resolution always fails. + const brokenFQDN = "rel://relay-bob-instance.invalid:52402" + if brokenFQDN == bobRealAddr { + t.Fatalf("broken FQDN must differ from bob's real address (%s)", bobRealAddr) + } + + t.Run("no server IP, dial fails", func(t *testing.T) { + dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second) + defer dialCancel() + _, err := mgrAlice.OpenConn(dialCtx, brokenFQDN, "bob", netip.Addr{}) + if err == nil { + t.Fatalf("expected OpenConn to fail without server IP, got success") + } + }) + + t.Run("server IP recovers", func(t *testing.T) { + // Bob waits for Alice's incoming peer connection on his side. + bobSideCh := make(chan error, 1) + go func() { + conn, err := mgrBob.OpenConn(ctx, bobRealAddr, "alice", netip.Addr{}) + if err != nil { + bobSideCh <- err + return + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + bobSideCh <- err + return + } + if _, err := conn.Write(buf[:n]); err != nil { + bobSideCh <- err + return + } + bobSideCh <- nil + }() + + aliceConn, err := mgrAlice.OpenConn(ctx, brokenFQDN, "bob", bobAdvertisedIP) + if err != nil { + t.Fatalf("alice OpenConn with server IP: %s", err) + } + t.Cleanup(func() { _ = aliceConn.Close() }) + + payload := []byte("alice-to-bob") + if _, err := aliceConn.Write(payload); err != nil { + t.Fatalf("alice write: %s", err) + } + + buf := make([]byte, len(payload)) + if _, err := io.ReadFull(aliceConn, buf); err != nil { + t.Fatalf("alice read echo: %s", err) + } + if string(buf) != string(payload) { + t.Fatalf("echo mismatch: got %q want %q", buf, payload) + } + + select { + case err := <-bobSideCh: + if err != nil { + t.Fatalf("bob side: %s", err) + } + case <-time.After(5 * 
time.Second): + t.Fatalf("timed out waiting for bob side") + } + }) +} diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index fb91f7682..9e964f688 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -2,6 +2,8 @@ package client import ( "context" + "fmt" + "net/netip" "testing" "time" @@ -100,15 +102,15 @@ func TestForeignConn(t *testing.T) { if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - bobsSrvAddr, err := clientBob.RelayInstanceAddress() + bobsSrvAddr, _, err := clientBob.RelayInstanceAddress() if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob") + connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice") + connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -208,7 +210,7 @@ func TestForeginConnClose(t *testing.T) { if err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob") + conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -300,7 +302,7 @@ func TestForeignAutoClose(t *testing.T) { } t.Log("open connection to another peer") - if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil { + if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer", netip.Addr{}); err == nil { t.Fatalf("should have failed to open connection to another peer") } @@ -360,16 +362,17 @@ func TestAutoReconnect(t *testing.T) { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU) + clientAlice := 
NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU, + WithMaxBackoffInterval(2*time.Second)) err = clientAlice.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) } - ra, err := clientAlice.RelayInstanceAddress() + ra, _, err := clientAlice.RelayInstanceAddress() if err != nil { t.Errorf("failed to get relay address: %s", err) } - conn, err := clientAlice.OpenConn(ctx, ra, "bob") + conn, err := clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{}) if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -384,15 +387,32 @@ func TestAutoReconnect(t *testing.T) { } log.Infof("waiting for reconnection") - time.Sleep(reconnectingTimeout + 1*time.Second) + if err := waitForReady(ctx, clientAlice, 15*time.Second); err != nil { + t.Fatalf("manager did not reconnect: %s", err) + } log.Infof("reopent the connection") - _, err = clientAlice.OpenConn(ctx, ra, "bob") + _, err = clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{}) if err != nil { t.Errorf("failed to open channel: %s", err) } } +func waitForReady(ctx context.Context, m *Manager, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if m.Ready() { + return nil + } + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } + return fmt.Errorf("manager not ready within %s", timeout) +} + func TestNotifierDoubleAdd(t *testing.T) { ctx := context.Background() @@ -434,7 +454,7 @@ func TestNotifierDoubleAdd(t *testing.T) { t.Fatalf("failed to serve manager: %s", err) } - conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob") + conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/shared/signal/client/client.go b/shared/signal/client/client.go index 5347c80e9..9dc6ccd37 100644 --- a/shared/signal/client/client.go +++ b/shared/signal/client/client.go @@ -4,6 +4,7 @@ 
import ( "context" "fmt" "io" + "net/netip" "strings" "github.com/netbirdio/netbird/shared/signal/proto" @@ -14,17 +15,17 @@ import ( // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. -// Status is the status of the client -type Status string - -const StreamConnected Status = "Connected" -const StreamDisconnected Status = "Disconnected" - const ( + StreamConnected Status = "Connected" + StreamDisconnected Status = "Disconnected" + // DirectCheck indicates support to direct mode checks DirectCheck uint32 = 1 ) +// Status is the status of the client +type Status string + type Client interface { io.Closer StreamConnected() bool @@ -38,6 +39,24 @@ type Client interface { SetOnReconnectedListener(func()) } +// Credential is an instance of a GrpcClient's Credential +type Credential struct { + UFrag string + Pwd string +} + +// CredentialPayload bundles the fields of a signal Body for MarshalCredential. +type CredentialPayload struct { + Type proto.Body_Type + WgListenPort int + Credential *Credential + RosenpassPubKey []byte + RosenpassAddr string + RelaySrvAddress string + RelaySrvIP netip.Addr + SessionID []byte +} + // UnMarshalCredential parses the credentials from the message and returns a Credential instance func UnMarshalCredential(msg *proto.Message) (*Credential, error) { @@ -52,27 +71,27 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) { } // MarshalCredential marshal a Credential instance and returns a Message object -func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string, sessionID []byte) (*proto.Message, error) { +func MarshalCredential(myKey wgtypes.Key, remoteKey string, p CredentialPayload) (*proto.Message, error) { + body := &proto.Body{ + Type: p.Type, + Payload: fmt.Sprintf("%s:%s", p.Credential.UFrag, p.Credential.Pwd), + WgListenPort: uint32(p.WgListenPort), + 
NetBirdVersion: version.NetbirdVersion(), + RosenpassConfig: &proto.RosenpassConfig{ + RosenpassPubKey: p.RosenpassPubKey, + RosenpassServerAddr: p.RosenpassAddr, + }, + SessionId: p.SessionID, + } + if p.RelaySrvAddress != "" { + body.RelayServerAddress = &p.RelaySrvAddress + } + if p.RelaySrvIP.IsValid() { + body.RelayServerIP = p.RelaySrvIP.Unmap().AsSlice() + } return &proto.Message{ Key: myKey.PublicKey().String(), RemoteKey: remoteKey, - Body: &proto.Body{ - Type: t, - Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd), - WgListenPort: uint32(myPort), - NetBirdVersion: version.NetbirdVersion(), - RosenpassConfig: &proto.RosenpassConfig{ - RosenpassPubKey: rosenpassPubKey, - RosenpassServerAddr: rosenpassAddr, - }, - RelayServerAddress: relaySrvAddress, - SessionId: sessionID, - }, + Body: body, }, nil } - -// Credential is an instance of a GrpcClient's Credential -type Credential struct { - UFrag string - Pwd string -} diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5368b57a2..b245b2296 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -23,6 +23,8 @@ import ( "github.com/netbirdio/netbird/util/wsproxy" ) +const healthCheckTimeout = 5 * time.Second + // ConnStateNotifier is a wrapper interface of the status recorder type ConnStateNotifier interface { MarkSignalDisconnected(error) @@ -165,7 +167,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes // start receiving messages from the Signal stream (from other peers through signal) err = c.receive(stream) if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { + if ctx.Err() != nil { log.Debugf("signal connection context has been canceled, this usually indicates shutdown") return nil } @@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, 
healthCheckTimeout) defer cancel() _, err := c.realClient.Send(ctx, &proto.EncryptedMessage{ Key: c.key.PublicKey().String(), diff --git a/shared/signal/proto/signalexchange.pb.go b/shared/signal/proto/signalexchange.pb.go index d9c61a846..0c80fb489 100644 --- a/shared/signal/proto/signalexchange.pb.go +++ b/shared/signal/proto/signalexchange.pb.go @@ -229,8 +229,13 @@ type Body struct { // RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"` // relayServerAddress is url of the relay server - RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"` - SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"` + RelayServerAddress *string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3,oneof" json:"relayServerAddress,omitempty"` + SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"` + // relayServerIP is the IP the sender is connected to on its relay server, + // encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a + // fallback dial target when DNS resolution of relayServerAddress fails. + // SNI/TLS verification still uses relayServerAddress. 
+ RelayServerIP []byte `protobuf:"bytes,11,opt,name=relayServerIP,proto3,oneof" json:"relayServerIP,omitempty"` } func (x *Body) Reset() { @@ -315,8 +320,8 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig { } func (x *Body) GetRelayServerAddress() string { - if x != nil { - return x.RelayServerAddress + if x != nil && x.RelayServerAddress != nil { + return *x.RelayServerAddress } return "" } @@ -328,6 +333,13 @@ func (x *Body) GetSessionId() []byte { return nil } +func (x *Body) GetRelayServerIP() []byte { + if x != nil { + return x.RelayServerIP + } + return nil +} + // Mode indicates a connection mode type Mode struct { state protoimpl.MessageState @@ -451,7 +463,7 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, - 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, + 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xc3, 0x04, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, @@ -471,40 +483,46 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 
0x12, 0x33, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, + 0x28, 0x09, 0x48, 0x00, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x88, 0x01, 0x01, 0x12, 0x21, 0x0a, 0x09, 0x73, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x01, + 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x29, + 0x0a, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x02, 0x52, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, - 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c, - 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04, - 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, - 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 
0x0a, 0x0f, - 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, - 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, - 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, - 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, - 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, - 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, - 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 
0x45, 0x10, 0x05, 0x42, 0x15, + 0x0a, 0x13, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, + 0x6e, 0x49, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x49, 0x50, 0x4a, 0x04, 0x08, 0x09, 0x10, 0x0a, 0x22, 0x2e, 0x0a, 0x04, 0x4d, + 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, + 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, + 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, + 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, + 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, + 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, + 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, + 0x6f, 
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, + 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/shared/signal/proto/signalexchange.proto b/shared/signal/proto/signalexchange.proto index 0a33ad78b..96a4001e3 100644 --- a/shared/signal/proto/signalexchange.proto +++ b/shared/signal/proto/signalexchange.proto @@ -63,9 +63,17 @@ message Body { RosenpassConfig rosenpassConfig = 7; // relayServerAddress is url of the relay server - string relayServerAddress = 8; + optional string relayServerAddress = 8; + + reserved 9; optional bytes sessionId = 10; + + // relayServerIP is the IP the sender is connected to on its relay server, + // encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a + // fallback dial target when DNS resolution of relayServerAddress fails. + // SNI/TLS verification still uses relayServerAddress. 
+ optional bytes relayServerIP = 11; } // Mode indicates a connection mode diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 523beb32b..4855e1aed 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -10,15 +10,13 @@ import ( "context" "fmt" "net" - "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/google/gopacket/routing" - "github.com/libp2p/go-netroute" "github.com/mdlayher/socket" log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" @@ -37,8 +35,6 @@ type SharedSocket struct { conn6 *socket.Conn port int mtu uint16 - routerMux sync.RWMutex - router routing.Router packetDemux chan rcvdPacket cancel context.CancelFunc } @@ -82,11 +78,6 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } }() - rawSock.router, err = netroute.New() - if err != nil { - return nil, fmt.Errorf("failed to create raw socket router: %w", err) - } - rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) if err != nil { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) @@ -127,31 +118,26 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error go rawSock.read(rawSock.conn6.Recvfrom) } - go rawSock.updateRouter() - return rawSock, nil } -// updateRouter updates the listener routing table client -// this is needed to avoid outdated information across different client networks -func (s *SharedSocket) updateRouter() { - ticker := time.NewTicker(15 * time.Second) - defer ticker.Stop() - for { - select { - case <-s.ctx.Done(): - return - case <-ticker.C: - router, err := netroute.New() - if err != nil { - log.Errorf("Failed to create and update packet router for stunListener: %s", err) - continue - } - s.routerMux.Lock() - s.router = router - s.routerMux.Unlock() +// resolveSrc returns the source IP the kernel will pick for a 
packet sent to +// dst by these raw sockets, mirroring the fwmark the kernel will see on send. +func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) { + opts := &netlink.RouteGetOptions{} + if nbnet.AdvancedRouting() { + opts.Mark = nbnet.ControlPlaneMark + } + routes, err := netlink.RouteGetWithOptions(dst, opts) + if err != nil { + return nil, fmt.Errorf("route get %s: %w", dst, err) + } + for _, r := range routes { + if r.Src != nil { + return r.Src, nil } } + return nil, fmt.Errorf("no source IP for %s", dst) } // LocalAddr returns the local address, preferring IPv4 for backward compatibility. @@ -310,15 +296,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { DstPort: layers.UDPPort(rUDPAddr.Port), } - s.routerMux.RLock() - defer s.routerMux.RUnlock() - - _, _, src, err := s.router.Route(rUDPAddr.IP) + src, err := s.resolveSrc(rUDPAddr.IP) if err != nil { - return 0, fmt.Errorf("got an error while checking route, err: %w", err) + return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err) } rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) + if conn == nil { + return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP) + } if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil { return -1, fmt.Errorf("failed to set network layer for checksum: %w", err) diff --git a/upload-server/server/s3_test.go b/upload-server/server/s3_test.go index 7ab1bb379..a72356409 100644 --- a/upload-server/server/s3_test.go +++ b/upload-server/server/s3_test.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "net" "net/http" "net/http/httptest" "runtime" @@ -52,7 +53,7 @@ func Test_S3HandlerGetUploadURL(t *testing.T) { hostIP, err := c.Host(ctx) require.NoError(t, err) - awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port() + awsEndpoint := "http://" + net.JoinHostPort(hostIP, mappedPort.Port()) t.Setenv("AWS_REGION", awsRegion) t.Setenv("AWS_ENDPOINT_URL", awsEndpoint) diff --git 
a/util/capture/afpacket_linux.go b/util/capture/afpacket_linux.go new file mode 100644 index 000000000..bf59e806a --- /dev/null +++ b/util/capture/afpacket_linux.go @@ -0,0 +1,199 @@ +package capture + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +// htons converts a uint16 from host to network (big-endian) byte order. +func htons(v uint16) uint16 { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], v) + return binary.NativeEndian.Uint16(buf[:]) +} + +// AFPacketCapture reads raw packets from a network interface using an +// AF_PACKET socket. This is the kernel-mode fallback when FilteredDevice is +// not available (kernel WireGuard). Linux only. +// +// It implements device.PacketCapture so it can be set on a Session, but it +// drives its own read loop rather than being called from FilteredDevice. +// Call Start to begin and Stop to end. +type AFPacketCapture struct { + ifaceName string + sess *Session + fd int + mu sync.Mutex + stopped chan struct{} + started atomic.Bool + closed atomic.Bool +} + +// NewAFPacketCapture creates a capture bound to the given interface. +// The session receives packets via Offer. +func NewAFPacketCapture(ifaceName string, sess *Session) *AFPacketCapture { + return &AFPacketCapture{ + ifaceName: ifaceName, + sess: sess, + fd: -1, + stopped: make(chan struct{}), + } +} + +// Start opens the AF_PACKET socket and begins reading packets. +// Packets are fed to the session via Offer. Returns immediately; +// the read loop runs in a goroutine. 
+func (c *AFPacketCapture) Start() error { + if c.sess == nil { + return errors.New("nil capture session") + } + if !c.started.CompareAndSwap(false, true) { + return errors.New("capture already started") + } + if c.closed.Load() { + c.started.Store(false) + return errors.New("cannot restart stopped capture") + } + + iface, err := net.InterfaceByName(c.ifaceName) + if err != nil { + c.started.Store(false) + return fmt.Errorf("interface %s: %w", c.ifaceName, err) + } + + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL))) + if err != nil { + c.started.Store(false) + return fmt.Errorf("create AF_PACKET socket: %w", err) + } + + addr := &unix.SockaddrLinklayer{ + Protocol: htons(unix.ETH_P_ALL), + Ifindex: iface.Index, + } + if err := unix.Bind(fd, addr); err != nil { + unix.Close(fd) + c.started.Store(false) + return fmt.Errorf("bind to %s: %w", c.ifaceName, err) + } + + c.mu.Lock() + c.fd = fd + c.mu.Unlock() + + go c.readLoop(fd) + return nil +} + +// Stop closes the socket and waits for the read loop to exit. Idempotent. +func (c *AFPacketCapture) Stop() { + if !c.closed.CompareAndSwap(false, true) { + if c.started.Load() { + <-c.stopped + } + return + } + + c.mu.Lock() + fd := c.fd + c.fd = -1 + c.mu.Unlock() + + if fd >= 0 { + unix.Close(fd) + } + + if c.started.Load() { + <-c.stopped + } +} + +func (c *AFPacketCapture) readLoop(fd int) { + defer close(c.stopped) + + buf := make([]byte, 65536) + pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}} + + for { + if c.closed.Load() { + return + } + + ok, err := c.pollOnce(pollFds) + if err != nil { + return + } + if !ok { + continue + } + + c.recvAndOffer(fd, buf) + } +} + +// pollOnce waits for data on the fd. Returns true if data is ready, false for timeout/retry. +// Returns an error to signal the loop should exit. 
+func (c *AFPacketCapture) pollOnce(pollFds []unix.PollFd) (bool, error) { + n, err := unix.Poll(pollFds, 200) + if err != nil { + if errors.Is(err, unix.EINTR) { + return false, nil + } + if c.closed.Load() { + return false, errors.New("closed") + } + log.Debugf("af_packet poll: %v", err) + return false, err + } + if n == 0 { + return false, nil + } + if pollFds[0].Revents&(unix.POLLERR|unix.POLLHUP|unix.POLLNVAL) != 0 { + return false, errors.New("fd error") + } + return true, nil +} + +func (c *AFPacketCapture) recvAndOffer(fd int, buf []byte) { + nr, from, err := unix.Recvfrom(fd, buf, 0) + if err != nil { + if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) { + return + } + if !c.closed.Load() { + log.Debugf("af_packet recvfrom: %v", err) + } + return + } + if nr < 1 { + return + } + + ver := buf[0] >> 4 + if ver != 4 && ver != 6 { + return + } + + // The kernel sets Pkttype on AF_PACKET sockets: + // PACKET_HOST(0) = addressed to us (inbound) + // PACKET_OUTGOING(4) = sent by us (outbound) + outbound := false + if sa, ok := from.(*unix.SockaddrLinklayer); ok { + outbound = sa.Pkttype == unix.PACKET_OUTGOING + } + c.sess.Offer(buf[:nr], outbound) +} + +// Offer satisfies device.PacketCapture but is unused: the AFPacketCapture +// drives its own read loop. This exists only so the type signature is +// compatible if someone tries to set it as a PacketCapture. +func (c *AFPacketCapture) Offer([]byte, bool) { + // unused: AFPacketCapture drives its own read loop +} diff --git a/util/capture/afpacket_stub.go b/util/capture/afpacket_stub.go new file mode 100644 index 000000000..bde368e88 --- /dev/null +++ b/util/capture/afpacket_stub.go @@ -0,0 +1,26 @@ +//go:build !linux + +package capture + +import "errors" + +// AFPacketCapture is not available on this platform. +type AFPacketCapture struct{} + +// NewAFPacketCapture returns nil on non-Linux platforms. 
+func NewAFPacketCapture(string, *Session) *AFPacketCapture { return nil } + +// Start returns an error on non-Linux platforms. +func (c *AFPacketCapture) Start() error { + return errors.New("AF_PACKET capture is only supported on Linux") +} + +// Stop is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Stop() { + // no-op on non-Linux platforms +} + +// Offer is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Offer([]byte, bool) { + // no-op on non-Linux platforms +} diff --git a/util/capture/capture.go b/util/capture/capture.go new file mode 100644 index 000000000..0d92a4311 --- /dev/null +++ b/util/capture/capture.go @@ -0,0 +1,59 @@ +// Package capture provides userspace packet capture in pcap format. +// +// It taps decrypted WireGuard packets flowing through the FilteredDevice and +// writes them as pcap (readable by tcpdump, tshark, Wireshark) or as +// human-readable one-line-per-packet text. +package capture + +import "io" + +// Direction indicates whether a packet is entering or leaving the host. +type Direction uint8 + +const ( + // Inbound is a packet arriving from the network (FilteredDevice.Write path). + Inbound Direction = iota + // Outbound is a packet leaving the host (FilteredDevice.Read path). + Outbound +) + +// String returns "IN" or "OUT". +func (d Direction) String() string { + if d == Outbound { + return "OUT" + } + return "IN" +} + +const ( + protoICMP = 1 + protoTCP = 6 + protoUDP = 17 + protoICMPv6 = 58 +) + +// Options configures a capture session. +type Options struct { + // Output receives pcap-formatted data. Nil disables pcap output. + Output io.Writer + // TextOutput receives human-readable packet summaries. Nil disables text output. + TextOutput io.Writer + // Matcher selects which packets to capture. Nil captures all. + // Use ParseFilter("host 10.0.0.1 and tcp") or &Filter{...}. + Matcher Matcher + // Verbose adds seq/ack, TTL, window, total length to text output. 
+ Verbose bool + // ASCII dumps transport payload as printable ASCII after each packet line. + ASCII bool + // SnapLen is the maximum bytes captured per packet. 0 means 65535. + SnapLen uint32 + // BufSize is the internal channel buffer size. 0 means 256. + BufSize int +} + +// Stats reports capture session counters. +type Stats struct { + Packets int64 + Bytes int64 + Dropped int64 +} diff --git a/util/capture/filter.go b/util/capture/filter.go new file mode 100644 index 000000000..d463450b8 --- /dev/null +++ b/util/capture/filter.go @@ -0,0 +1,528 @@ +package capture + +import ( + "encoding/binary" + "fmt" + "net/netip" + "strconv" + "strings" +) + +// Matcher tests whether a raw packet should be captured. +type Matcher interface { + Match(data []byte) bool +} + +// Filter selects packets by flat AND'd criteria. Useful for structured APIs +// (query params, proto fields). Implements Matcher. +type Filter struct { + SrcIP netip.Addr + DstIP netip.Addr + Host netip.Addr + SrcPort uint16 + DstPort uint16 + Port uint16 + Proto uint8 +} + +// IsEmpty returns true if the filter has no criteria set. +func (f *Filter) IsEmpty() bool { + return !f.SrcIP.IsValid() && !f.DstIP.IsValid() && !f.Host.IsValid() && + f.SrcPort == 0 && f.DstPort == 0 && f.Port == 0 && f.Proto == 0 +} + +// Match implements Matcher. All non-zero fields must match (AND). 
+func (f *Filter) Match(data []byte) bool { + if f.IsEmpty() { + return true + } + info, ok := parsePacketInfo(data) + if !ok { + return false + } + if f.Host.IsValid() && info.srcIP != f.Host && info.dstIP != f.Host { + return false + } + if f.SrcIP.IsValid() && info.srcIP != f.SrcIP { + return false + } + if f.DstIP.IsValid() && info.dstIP != f.DstIP { + return false + } + if f.Proto != 0 && info.proto != f.Proto { + return false + } + if f.Port != 0 && info.srcPort != f.Port && info.dstPort != f.Port { + return false + } + if f.SrcPort != 0 && info.srcPort != f.SrcPort { + return false + } + if f.DstPort != 0 && info.dstPort != f.DstPort { + return false + } + return true +} + +// exprNode evaluates a filter condition against pre-parsed packet info. +type exprNode func(info *packetInfo) bool + +// exprMatcher wraps an expression tree. Parses the packet once, then walks the tree. +type exprMatcher struct { + root exprNode +} + +func (m *exprMatcher) Match(data []byte) bool { + info, ok := parsePacketInfo(data) + if !ok { + return false + } + return m.root(&info) +} + +func nodeAnd(a, b exprNode) exprNode { + return func(info *packetInfo) bool { return a(info) && b(info) } +} + +func nodeOr(a, b exprNode) exprNode { + return func(info *packetInfo) bool { return a(info) || b(info) } +} + +func nodeNot(n exprNode) exprNode { + return func(info *packetInfo) bool { return !n(info) } +} + +func nodeHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.srcIP == addr || info.dstIP == addr } +} + +func nodeSrcHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.srcIP == addr } +} + +func nodeDstHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.dstIP == addr } +} + +func nodePort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.srcPort == port || info.dstPort == port } +} + +func nodeSrcPort(port uint16) exprNode { + return func(info *packetInfo) bool { 
return info.srcPort == port } +} + +func nodeDstPort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.dstPort == port } +} + +func nodeProto(proto uint8) exprNode { + return func(info *packetInfo) bool { return info.proto == proto } +} + +func nodeFamily(family uint8) exprNode { + return func(info *packetInfo) bool { return info.family == family } +} + +func nodeNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) || prefix.Contains(info.dstIP) } +} + +func nodeSrcNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) } +} + +func nodeDstNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.dstIP) } +} + +// packetInfo holds parsed header fields for filtering and display. +type packetInfo struct { + family uint8 + srcIP netip.Addr + dstIP netip.Addr + proto uint8 + srcPort uint16 + dstPort uint16 + hdrLen int +} + +func parsePacketInfo(data []byte) (packetInfo, bool) { + if len(data) < 1 { + return packetInfo{}, false + } + switch data[0] >> 4 { + case 4: + return parseIPv4Info(data) + case 6: + return parseIPv6Info(data) + default: + return packetInfo{}, false + } +} + +func parseIPv4Info(data []byte) (packetInfo, bool) { + if len(data) < 20 { + return packetInfo{}, false + } + ihl := int(data[0]&0x0f) * 4 + if ihl < 20 || len(data) < ihl { + return packetInfo{}, false + } + info := packetInfo{ + family: 4, + srcIP: netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}), + dstIP: netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}), + proto: data[9], + hdrLen: ihl, + } + if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= ihl+4 { + info.srcPort = binary.BigEndian.Uint16(data[ihl:]) + info.dstPort = binary.BigEndian.Uint16(data[ihl+2:]) + } + return info, true +} + +// parseIPv6Info parses the fixed IPv6 header. 
It reads the Next Header field +// directly, so packets with extension headers (hop-by-hop, routing, fragment, +// etc.) will report the extension type as the protocol rather than the final +// transport protocol. This is acceptable for a debug capture tool. +func parseIPv6Info(data []byte) (packetInfo, bool) { + if len(data) < 40 { + return packetInfo{}, false + } + var src, dst [16]byte + copy(src[:], data[8:24]) + copy(dst[:], data[24:40]) + info := packetInfo{ + family: 6, + srcIP: netip.AddrFrom16(src), + dstIP: netip.AddrFrom16(dst), + proto: data[6], + hdrLen: 40, + } + if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= 44 { + info.srcPort = binary.BigEndian.Uint16(data[40:]) + info.dstPort = binary.BigEndian.Uint16(data[42:]) + } + return info, true +} + +// ParseFilter parses a BPF-like filter expression and returns a Matcher. +// Returns nil Matcher for an empty expression (match all). +// +// Grammar (mirrors common tcpdump BPF syntax): +// +// orExpr = andExpr ("or" andExpr)* +// andExpr = unary ("and" unary)* +// unary = "not" unary | "(" orExpr ")" | term +// +// term = "host" IP | "src" target | "dst" target +// | "port" NUM | "net" PREFIX +// | "tcp" | "udp" | "icmp" | "icmp6" +// | "ip" | "ip6" | "proto" NUM +// target = "host" IP | "port" NUM | "net" PREFIX | IP +// +// Examples: +// +// host 10.0.0.1 and tcp port 443 +// not port 22 +// (host 10.0.0.1 or host 10.0.0.2) and tcp +// ip6 and icmp6 +// net 10.0.0.0/24 +// src host 10.0.0.1 or dst port 80 +func ParseFilter(expr string) (Matcher, error) { + tokens := tokenize(expr) + if len(tokens) == 0 { + return nil, nil //nolint:nilnil // nil Matcher means "match all" + } + + p := &parser{tokens: tokens} + node, err := p.parseOr() + if err != nil { + return nil, err + } + if p.pos < len(p.tokens) { + return nil, fmt.Errorf("unexpected token %q at position %d", p.tokens[p.pos], p.pos) + } + return &exprMatcher{root: node}, nil +} + +func tokenize(expr string) []string { + expr = 
strings.TrimSpace(expr)
	if expr == "" {
		return nil
	}
	// Split on whitespace but keep parens as separate tokens.
	var tokens []string
	for _, field := range strings.Fields(expr) {
		tokens = append(tokens, splitParens(field)...)
	}
	return tokens
}

// splitParens splits "(foo)" into "(", "foo", ")".
// Any run of leading "(" and trailing ")" is peeled off; parens in the middle
// of a token are left attached.
func splitParens(s string) []string {
	var out []string
	for strings.HasPrefix(s, "(") {
		out = append(out, "(")
		s = s[1:]
	}
	var trail []string
	for strings.HasSuffix(s, ")") {
		trail = append(trail, ")")
		s = s[:len(s)-1]
	}
	if s != "" {
		out = append(out, s)
	}
	out = append(out, trail...)
	return out
}

// parser is a recursive-descent parser over the token stream produced by
// tokenize. pos indexes the next unconsumed token.
type parser struct {
	tokens []string
	pos    int
}

// peek returns the next token lowercased without consuming it, or "" at EOF.
// Lowercasing here makes keywords case-insensitive; IPs and numbers are
// unaffected by it.
func (p *parser) peek() string {
	if p.pos >= len(p.tokens) {
		return ""
	}
	return strings.ToLower(p.tokens[p.pos])
}

// next consumes and returns the next token (lowercased), or "" at EOF.
func (p *parser) next() string {
	tok := p.peek()
	if tok != "" {
		p.pos++
	}
	return tok
}

// expect consumes the next token and errors unless it equals tok.
func (p *parser) expect(tok string) error {
	got := p.next()
	if got != tok {
		return fmt.Errorf("expected %q, got %q", tok, got)
	}
	return nil
}

// parseOr parses "andExpr (or andExpr)*" — the lowest-precedence level.
func (p *parser) parseOr() (exprNode, error) {
	left, err := p.parseAnd()
	if err != nil {
		return nil, err
	}
	for p.peek() == "or" {
		p.next()
		right, err := p.parseAnd()
		if err != nil {
			return nil, err
		}
		left = nodeOr(left, right)
	}
	return left, nil
}

// parseAnd parses "unary (and unary)*", also accepting plain juxtaposition
// ("tcp port 443") as an implicit AND, matching tcpdump behavior.
func (p *parser) parseAnd() (exprNode, error) {
	left, err := p.parseUnary()
	if err != nil {
		return nil, err
	}
	for {
		tok := p.peek()
		if tok == "and" {
			p.next()
			right, err := p.parseUnary()
			if err != nil {
				return nil, err
			}
			left = nodeAnd(left, right)
			continue
		}
		// Implicit AND: two atoms without "and" between them.
		// Only if the next token starts an atom (not "or", ")", or EOF).
		if tok != "" && tok != "or" && tok != ")" {
			right, err := p.parseUnary()
			if err != nil {
				return nil, err
			}
			left = nodeAnd(left, right)
			continue
		}
		break
	}
	return left, nil
}

// parseUnary handles "not", parenthesized groups, and plain atoms.
func (p *parser) parseUnary() (exprNode, error) {
	switch p.peek() {
	case "not":
		p.next()
		inner, err := p.parseUnary()
		if err != nil {
			return nil, err
		}
		return nodeNot(inner), nil
	case "(":
		p.next()
		inner, err := p.parseOr()
		if err != nil {
			return nil, err
		}
		if err := p.expect(")"); err != nil {
			return nil, fmt.Errorf("unclosed parenthesis")
		}
		return inner, nil
	default:
		return p.parseAtom()
	}
}

// parseAtom parses one primitive predicate: a keyword plus its argument
// (host/port/net/proto), a src/dst-qualified target, or a bare protocol or
// family keyword.
func (p *parser) parseAtom() (exprNode, error) {
	tok := p.next()
	if tok == "" {
		return nil, fmt.Errorf("unexpected end of expression")
	}

	switch tok {
	case "host":
		addr, err := p.parseAddr()
		if err != nil {
			return nil, fmt.Errorf("host: %w", err)
		}
		return nodeHost(addr), nil

	case "port":
		port, err := p.parsePort()
		if err != nil {
			return nil, fmt.Errorf("port: %w", err)
		}
		return nodePort(port), nil

	case "net":
		prefix, err := p.parsePrefix()
		if err != nil {
			return nil, fmt.Errorf("net: %w", err)
		}
		return nodeNet(prefix), nil

	case "src":
		return p.parseDirTarget(true)

	case "dst":
		return p.parseDirTarget(false)

	case "tcp":
		return nodeProto(protoTCP), nil
	case "udp":
		return nodeProto(protoUDP), nil
	case "icmp":
		return nodeProto(protoICMP), nil
	case "icmp6":
		return nodeProto(protoICMPv6), nil
	case "ip":
		return nodeFamily(4), nil
	case "ip6":
		return nodeFamily(6), nil

	case "proto":
		raw := p.next()
		if raw == "" {
			return nil, fmt.Errorf("proto: missing number")
		}
		n, err := strconv.Atoi(raw)
		if err != nil || n < 0 || n > 255 {
			return nil, fmt.Errorf("proto: invalid number %q", raw)
		}
		return nodeProto(uint8(n)), nil

	default:
		return nil, fmt.Errorf("unknown filter keyword %q", tok)
	}
}

// parseDirTarget parses the target of "src"/"dst": "host IP", "port NUM",
// "net PREFIX", or a bare IP address shorthand.
func (p *parser) parseDirTarget(isSrc bool)
(exprNode, error) { + tok := p.peek() + switch tok { + case "host": + p.next() + addr, err := p.parseAddr() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcHost(addr), nil + } + return nodeDstHost(addr), nil + + case "port": + p.next() + port, err := p.parsePort() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcPort(port), nil + } + return nodeDstPort(port), nil + + case "net": + p.next() + prefix, err := p.parsePrefix() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcNet(prefix), nil + } + return nodeDstNet(prefix), nil + + default: + // Try as bare IP: "src 10.0.0.1" + addr, err := p.parseAddr() + if err != nil { + return nil, fmt.Errorf("expected host, port, net, or IP after src/dst, got %q", tok) + } + if isSrc { + return nodeSrcHost(addr), nil + } + return nodeDstHost(addr), nil + } +} + +func (p *parser) parseAddr() (netip.Addr, error) { + raw := p.next() + if raw == "" { + return netip.Addr{}, fmt.Errorf("missing IP address") + } + addr, err := netip.ParseAddr(raw) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IP %q", raw) + } + return addr.Unmap(), nil +} + +func (p *parser) parsePort() (uint16, error) { + raw := p.next() + if raw == "" { + return 0, fmt.Errorf("missing port number") + } + n, err := strconv.Atoi(raw) + if err != nil || n < 1 || n > 65535 { + return 0, fmt.Errorf("invalid port %q", raw) + } + return uint16(n), nil +} + +func (p *parser) parsePrefix() (netip.Prefix, error) { + raw := p.next() + if raw == "" { + return netip.Prefix{}, fmt.Errorf("missing network prefix") + } + prefix, err := netip.ParsePrefix(raw) + if err != nil { + return netip.Prefix{}, fmt.Errorf("invalid prefix %q", raw) + } + return prefix, nil +} diff --git a/util/capture/filter_test.go b/util/capture/filter_test.go new file mode 100644 index 000000000..d5fd17566 --- /dev/null +++ b/util/capture/filter_test.go @@ -0,0 +1,263 @@ +package capture + +import ( + "net/netip" + "testing" + + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildIPv4Packet creates a minimal IPv4+TCP/UDP packet for filter testing. +func buildIPv4Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + hdrLen := 20 + pkt := make([]byte, hdrLen+20) + pkt[0] = 0x45 + pkt[9] = proto + + src := srcIP.As4() + dst := dstIP.As4() + copy(pkt[12:16], src[:]) + copy(pkt[16:20], dst[:]) + + pkt[20] = byte(srcPort >> 8) + pkt[21] = byte(srcPort) + pkt[22] = byte(dstPort >> 8) + pkt[23] = byte(dstPort) + + return pkt +} + +// buildIPv6Packet creates a minimal IPv6+TCP/UDP packet for filter testing. +func buildIPv6Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + pkt := make([]byte, 44) // 40 header + 4 ports + pkt[0] = 0x60 // version 6 + pkt[6] = proto // next header + + src := srcIP.As16() + dst := dstIP.As16() + copy(pkt[8:24], src[:]) + copy(pkt[24:40], dst[:]) + + pkt[40] = byte(srcPort >> 8) + pkt[41] = byte(srcPort) + pkt[42] = byte(dstPort >> 8) + pkt[43] = byte(dstPort) + + return pkt +} + +// ---- Filter struct tests ---- + +func TestFilter_Empty(t *testing.T) { + f := Filter{} + assert.True(t, f.IsEmpty()) + assert.True(t, f.Match(buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443))) +} + +func TestFilter_Host(t *testing.T) { + f := Filter{Host: netip.MustParseAddr("10.0.0.1")} + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 1234, 80))) + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.1"), protoTCP, 1234, 80))) + assert.False(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3"), protoTCP, 1234, 80))) +} + +func TestFilter_InvalidPacket(t *testing.T) { + f := Filter{Host: 
netip.MustParseAddr("10.0.0.1")} + assert.False(t, f.Match(nil)) + assert.False(t, f.Match([]byte{})) + assert.False(t, f.Match([]byte{0x00})) +} + +func TestParsePacketInfo_IPv4(t *testing.T) { + pkt := buildIPv4Packet(t, netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("10.0.0.1"), protoTCP, 54321, 80) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(4), info.family) + assert.Equal(t, netip.MustParseAddr("192.168.1.1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), info.dstIP) + assert.Equal(t, uint8(protoTCP), info.proto) + assert.Equal(t, uint16(54321), info.srcPort) + assert.Equal(t, uint16(80), info.dstPort) +} + +func TestParsePacketInfo_IPv6(t *testing.T) { + pkt := buildIPv6Packet(t, netip.MustParseAddr("fd00::1"), netip.MustParseAddr("fd00::2"), protoUDP, 1234, 53) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(6), info.family) + assert.Equal(t, netip.MustParseAddr("fd00::1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("fd00::2"), info.dstIP) + assert.Equal(t, uint8(protoUDP), info.proto) + assert.Equal(t, uint16(1234), info.srcPort) + assert.Equal(t, uint16(53), info.dstPort) +} + +// ---- ParseFilter expression tests ---- + +func matchV4(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv4Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func matchV6(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv6Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func TestParseFilter_Empty(t *testing.T) { + m, err := ParseFilter("") + require.NoError(t, err) + assert.Nil(t, m, "empty expression should return nil matcher") +} + +func TestParseFilter_Atoms(t *testing.T) { + tests := []struct { + expr string + match bool + }{ 
+ {"tcp", true}, + {"udp", false}, + {"host 10.0.0.1", true}, + {"host 10.0.0.99", false}, + {"port 443", true}, + {"port 80", false}, + {"src host 10.0.0.1", true}, + {"dst host 10.0.0.2", true}, + {"dst host 10.0.0.1", false}, + {"src port 12345", true}, + {"dst port 443", true}, + {"dst port 80", false}, + {"proto 6", true}, + {"proto 17", false}, + } + + pkt := buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 12345, 443) + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + m, err := ParseFilter(tt.expr) + require.NoError(t, err) + assert.Equal(t, tt.match, m.Match(pkt)) + }) + } +} + +func TestParseFilter_And(t *testing.T) { + m, err := ParseFilter("host 10.0.0.1 and tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 55555, 443), "wrong proto") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 80), "wrong port") +} + +func TestParseFilter_ImplicitAnd(t *testing.T) { + // "tcp port 443" = implicit AND between tcp and port 443 + m, err := ParseFilter("tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 1, 443)) +} + +func TestParseFilter_Or(t *testing.T) { + m, err := ParseFilter("port 80 or port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 80)) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080)) +} + +func TestParseFilter_Not(t *testing.T) { + m, err := ParseFilter("not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + 
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 22)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 80)) +} + +func TestParseFilter_Parens(t *testing.T) { + m, err := ParseFilter("(port 80 or port 443) and tcp") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoUDP, 1, 443), "wrong proto") + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080), "wrong port") +} + +func TestParseFilter_Family(t *testing.T) { + mV4, err := ParseFilter("ip") + require.NoError(t, err) + assert.True(t, matchV4(t, mV4, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.False(t, matchV6(t, mV4, "fd00::1", "fd00::2", protoTCP, 1, 80)) + + mV6, err := ParseFilter("ip6") + require.NoError(t, err) + assert.False(t, matchV4(t, mV6, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.True(t, matchV6(t, mV6, "fd00::1", "fd00::2", protoTCP, 1, 80)) +} + +func TestParseFilter_Net(t *testing.T) { + m, err := ParseFilter("net 10.0.0.0/24") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "192.168.1.1", protoTCP, 1, 80), "src in net") + assert.True(t, matchV4(t, m, "192.168.1.1", "10.0.0.200", protoTCP, 1, 80), "dst in net") + assert.False(t, matchV4(t, m, "10.0.1.1", "192.168.1.1", protoTCP, 1, 80), "neither in net") +} + +func TestParseFilter_SrcDstNet(t *testing.T) { + m, err := ParseFilter("src net 10.0.0.0/8 and dst net 192.168.0.0/16") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.1.2.3", "192.168.1.1", protoTCP, 1, 80)) + assert.False(t, matchV4(t, m, "192.168.1.1", "10.1.2.3", protoTCP, 1, 80), "reversed") +} + +func TestParseFilter_Complex(t *testing.T) { + // Real-world: capture HTTP(S) traffic to/from specific host, excluding SSH + m, err := ParseFilter("host 10.0.0.1 and (port 80 or port 443) and not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", 
protoTCP, 55555, 443)) + assert.True(t, matchV4(t, m, "10.0.0.2", "10.0.0.1", protoTCP, 55555, 80)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 443), "port 22 excluded") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") +} + +func TestParseFilter_IPv6Combined(t *testing.T) { + m, err := ParseFilter("ip6 and icmp6") + require.NoError(t, err) + assert.True(t, matchV6(t, m, "fd00::1", "fd00::2", protoICMPv6, 0, 0)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoICMP, 0, 0), "wrong family") + assert.False(t, matchV6(t, m, "fd00::1", "fd00::2", protoTCP, 1, 80), "wrong proto") +} + +func TestParseFilter_CaseInsensitive(t *testing.T) { + m, err := ParseFilter("HOST 10.0.0.1 AND TCP PORT 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) +} + +func TestParseFilter_Errors(t *testing.T) { + bad := []string{ + "badkeyword", + "host", + "port abc", + "port 99999", + "net invalid", + "(", + "(port 80", + "not", + "src", + } + for _, expr := range bad { + t.Run(expr, func(t *testing.T) { + _, err := ParseFilter(expr) + assert.Error(t, err, "should fail for %q", expr) + }) + } +} diff --git a/util/capture/pcap.go b/util/capture/pcap.go new file mode 100644 index 000000000..0a9057045 --- /dev/null +++ b/util/capture/pcap.go @@ -0,0 +1,85 @@ +package capture + +import ( + "encoding/binary" + "io" + "time" +) + +const ( + pcapMagic = 0xa1b2c3d4 + pcapVersionMaj = 2 + pcapVersionMin = 4 + // linkTypeRaw is LINKTYPE_RAW: raw IPv4/IPv6 packets without link-layer header. + linkTypeRaw = 101 + defaultSnapLen = 65535 +) + +// PcapWriter writes packets in pcap format to an underlying writer. +// The global header is written lazily on the first WritePacket call so that +// the writer can be used with unbuffered io.Pipes without deadlocking. +// It is not safe for concurrent use; callers must serialize access. 
+type PcapWriter struct { + w io.Writer + snapLen uint32 + headerWritten bool +} + +// NewPcapWriter creates a pcap writer. The global header is deferred until the +// first WritePacket call. +func NewPcapWriter(w io.Writer, snapLen uint32) *PcapWriter { + if snapLen == 0 { + snapLen = defaultSnapLen + } + return &PcapWriter{w: w, snapLen: snapLen} +} + +// writeGlobalHeader writes the 24-byte pcap file header. +func (pw *PcapWriter) writeGlobalHeader() error { + var hdr [24]byte + binary.LittleEndian.PutUint32(hdr[0:4], pcapMagic) + binary.LittleEndian.PutUint16(hdr[4:6], pcapVersionMaj) + binary.LittleEndian.PutUint16(hdr[6:8], pcapVersionMin) + binary.LittleEndian.PutUint32(hdr[16:20], pw.snapLen) + binary.LittleEndian.PutUint32(hdr[20:24], linkTypeRaw) + + _, err := pw.w.Write(hdr[:]) + return err +} + +// WriteHeader writes the pcap global header. Safe to call multiple times. +func (pw *PcapWriter) WriteHeader() error { + if pw.headerWritten { + return nil + } + if err := pw.writeGlobalHeader(); err != nil { + return err + } + pw.headerWritten = true + return nil +} + +// WritePacket writes a single packet record, preceded by the global header +// on the first call. 
+func (pw *PcapWriter) WritePacket(ts time.Time, data []byte) error { + if err := pw.WriteHeader(); err != nil { + return err + } + + origLen := uint32(len(data)) + if origLen > pw.snapLen { + data = data[:pw.snapLen] + } + + var hdr [16]byte + binary.LittleEndian.PutUint32(hdr[0:4], uint32(ts.Unix())) + binary.LittleEndian.PutUint32(hdr[4:8], uint32(ts.Nanosecond()/1000)) + binary.LittleEndian.PutUint32(hdr[8:12], uint32(len(data))) + binary.LittleEndian.PutUint32(hdr[12:16], origLen) + + if _, err := pw.w.Write(hdr[:]); err != nil { + return err + } + _, err := pw.w.Write(data) + return err +} diff --git a/util/capture/pcap_test.go b/util/capture/pcap_test.go new file mode 100644 index 000000000..c3d21ef4a --- /dev/null +++ b/util/capture/pcap_test.go @@ -0,0 +1,68 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPcapWriter_GlobalHeader(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 0) + + // Header is lazy, so write a dummy packet to trigger it. 
+ err := pw.WritePacket(time.Now(), []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 64, 1, 0, 0, 10, 0, 0, 1, 10, 0, 0, 2}) + require.NoError(t, err) + + data := buf.Bytes() + require.GreaterOrEqual(t, len(data), 24, "should contain global header") + + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]), "magic number") + assert.Equal(t, uint16(pcapVersionMaj), binary.LittleEndian.Uint16(data[4:6]), "version major") + assert.Equal(t, uint16(pcapVersionMin), binary.LittleEndian.Uint16(data[6:8]), "version minor") + assert.Equal(t, uint32(defaultSnapLen), binary.LittleEndian.Uint32(data[16:20]), "snap length") + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]), "link type") +} + +func TestPcapWriter_WritePacket(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 100) + + ts := time.Date(2025, 6, 15, 12, 30, 45, 123456000, time.UTC) + payload := make([]byte, 50) + for i := range payload { + payload[i] = byte(i) + } + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] // skip global header + require.Len(t, data, 16+50, "packet header + payload") + + assert.Equal(t, uint32(ts.Unix()), binary.LittleEndian.Uint32(data[0:4]), "timestamp seconds") + assert.Equal(t, uint32(123456), binary.LittleEndian.Uint32(data[4:8]), "timestamp microseconds") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[8:12]), "included length") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length") + assert.Equal(t, payload, data[16:], "packet data") +} + +func TestPcapWriter_SnapLen(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 10) + + ts := time.Now() + payload := make([]byte, 50) + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] + assert.Equal(t, uint32(10), binary.LittleEndian.Uint32(data[8:12]), "included length should be truncated") + assert.Equal(t, uint32(50), 
binary.LittleEndian.Uint32(data[12:16]), "original length preserved") + assert.Len(t, data[16:], 10, "only snap_len bytes written") +} diff --git a/util/capture/session.go b/util/capture/session.go new file mode 100644 index 000000000..09806e10c --- /dev/null +++ b/util/capture/session.go @@ -0,0 +1,213 @@ +package capture + +import ( + "fmt" + "sync" + "sync/atomic" + "time" +) + +const defaultBufSize = 256 + +type packetEntry struct { + ts time.Time + data []byte + dir Direction +} + +// Session manages an active packet capture. Packets are offered via Offer, +// buffered in a channel, and written to configured sinks by a background +// goroutine. This keeps the hot path (FilteredDevice.Read/Write) non-blocking. +// +// The caller must call Stop when done to flush remaining packets and release +// resources. +type Session struct { + pcapW *PcapWriter + textW *TextWriter + matcher Matcher + snapLen uint32 + flushFn func() + + ch chan packetEntry + done chan struct{} + stopped chan struct{} + + closeOnce sync.Once + closed atomic.Bool + packets atomic.Int64 + bytes atomic.Int64 + dropped atomic.Int64 + started time.Time +} + +// NewSession creates and starts a capture session. At least one of +// Options.Output or Options.TextOutput must be non-nil. 
+func NewSession(opts Options) (*Session, error) { + if opts.Output == nil && opts.TextOutput == nil { + return nil, fmt.Errorf("at least one output sink required") + } + + snapLen := opts.SnapLen + if snapLen == 0 { + snapLen = defaultSnapLen + } + + bufSize := opts.BufSize + if bufSize <= 0 { + bufSize = defaultBufSize + } + + s := &Session{ + matcher: opts.Matcher, + snapLen: snapLen, + ch: make(chan packetEntry, bufSize), + done: make(chan struct{}), + stopped: make(chan struct{}), + started: time.Now(), + } + + if opts.Output != nil { + s.pcapW = NewPcapWriter(opts.Output, snapLen) + } + if opts.TextOutput != nil { + s.textW = NewTextWriter(opts.TextOutput, opts.Verbose, opts.ASCII) + } + + s.flushFn = buildFlushFn(opts.Output, opts.TextOutput) + + go s.run() + return s, nil +} + +// Offer submits a packet for capture. It returns immediately and never blocks +// the caller. If the internal buffer is full the packet is dropped silently. +// +// outbound should be true for packets leaving the host (FilteredDevice.Read +// path) and false for packets arriving (FilteredDevice.Write path). +// +// Offer satisfies the device.PacketCapture interface. +func (s *Session) Offer(data []byte, outbound bool) { + if s.closed.Load() { + return + } + + if s.matcher != nil && !s.matcher.Match(data) { + return + } + + captureLen := len(data) + if s.snapLen > 0 && uint32(captureLen) > s.snapLen { + captureLen = int(s.snapLen) + } + + copied := make([]byte, captureLen) + copy(copied, data) + + dir := Inbound + if outbound { + dir = Outbound + } + + select { + case s.ch <- packetEntry{ts: time.Now(), data: copied, dir: dir}: + s.packets.Add(1) + s.bytes.Add(int64(len(data))) + default: + s.dropped.Add(1) + } +} + +// Stop signals the session to stop accepting packets, drains any buffered +// packets to the sinks, and waits for the writer goroutine to exit. +// It is safe to call multiple times. 
+func (s *Session) Stop() { + s.closeOnce.Do(func() { + s.closed.Store(true) + close(s.done) + }) + <-s.stopped +} + +// Done returns a channel that is closed when the session's writer goroutine +// has fully exited and all buffered packets have been flushed. +func (s *Session) Done() <-chan struct{} { + return s.stopped +} + +// Stats returns current capture counters. +func (s *Session) Stats() Stats { + return Stats{ + Packets: s.packets.Load(), + Bytes: s.bytes.Load(), + Dropped: s.dropped.Load(), + } +} + +func (s *Session) run() { + defer close(s.stopped) + + for { + select { + case pkt := <-s.ch: + s.write(pkt) + case <-s.done: + s.drain() + return + } + } +} + +func (s *Session) drain() { + for { + select { + case pkt := <-s.ch: + s.write(pkt) + default: + return + } + } +} + +func (s *Session) write(pkt packetEntry) { + if s.pcapW != nil { + // Best-effort: if the writer fails (broken pipe etc.), discard silently. + _ = s.pcapW.WritePacket(pkt.ts, pkt.data) + } + if s.textW != nil { + _ = s.textW.WritePacket(pkt.ts, pkt.data, pkt.dir) + } + s.flushFn() +} + +// buildFlushFn returns a function that flushes all writers that support it. +// This covers http.Flusher and similar streaming writers. 
+func buildFlushFn(writers ...any) func() { + type flusher interface { + Flush() + } + + var fns []func() + for _, w := range writers { + if w == nil { + continue + } + if f, ok := w.(flusher); ok { + fns = append(fns, f.Flush) + } + } + + switch len(fns) { + case 0: + return func() { + // no writers to flush + } + case 1: + return fns[0] + default: + return func() { + for _, fn := range fns { + fn() + } + } + } +} diff --git a/util/capture/session_test.go b/util/capture/session_test.go new file mode 100644 index 000000000..ab27686c6 --- /dev/null +++ b/util/capture/session_test.go @@ -0,0 +1,144 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSession_PcapOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + + sess.Offer(pkt, true) + sess.Stop() + + data := buf.Bytes() + require.Greater(t, len(data), 24, "should have global header + at least one packet") + + // Verify global header + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4])) + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24])) + + // Verify packet record + pktData := data[24:] + inclLen := binary.LittleEndian.Uint32(pktData[8:12]) + assert.Equal(t, uint32(len(pkt)), inclLen) + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets) + assert.Equal(t, int64(len(pkt)), stats.Bytes) + assert.Equal(t, int64(0), stats.Dropped) +} + +func TestSession_TextOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + TextOutput: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + 
protoTCP, 12345, 443) + + sess.Offer(pkt, false) + sess.Stop() + + output := buf.String() + assert.Contains(t, output, "TCP") + assert.Contains(t, output, "10.0.0.1") + assert.Contains(t, output, "10.0.0.2") + assert.Contains(t, output, "443") + assert.Contains(t, output, "[IN TCP]") +} + +func TestSession_Filter(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + Matcher: &Filter{Port: 443}, + }) + require.NoError(t, err) + + pktMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + pktNoMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 80) + + sess.Offer(pktMatch, true) + sess.Offer(pktNoMatch, true) + sess.Stop() + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets, "only matching packet should be captured") +} + +func TestSession_StopIdempotent(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + sess.Stop() + sess.Stop() // should not panic or deadlock +} + +func TestSession_OfferAfterStop(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + sess.Stop() + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + sess.Offer(pkt, true) // should not panic + + assert.Equal(t, int64(0), sess.Stats().Packets) +} + +func TestSession_Done(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + select { + case <-sess.Done(): + t.Fatal("Done should not be closed before Stop") + default: + } + + sess.Stop() + + select { + case <-sess.Done(): + case <-time.After(time.Second): + t.Fatal("Done should be closed after Stop") + } +} + +func TestSession_RequiresOutput(t *testing.T) { + _, err := NewSession(Options{}) + assert.Error(t, 
err) +} diff --git a/util/capture/text.go b/util/capture/text.go new file mode 100644 index 000000000..fbb26654e --- /dev/null +++ b/util/capture/text.go @@ -0,0 +1,640 @@ +package capture + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "strings" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// TextWriter writes human-readable one-line-per-packet summaries. +// It is not safe for concurrent use; callers must serialize access. +type TextWriter struct { + w io.Writer + verbose bool + ascii bool + flows map[dirKey]uint32 +} + +type dirKey struct { + src netip.AddrPort + dst netip.AddrPort +} + +// NewTextWriter creates a text formatter that writes to w. +func NewTextWriter(w io.Writer, verbose, ascii bool) *TextWriter { + return &TextWriter{ + w: w, + verbose: verbose, + ascii: ascii, + flows: make(map[dirKey]uint32), + } +} + +// tag formats the fixed-width "[DIR PROTO]" prefix with right-aligned protocol. +func tag(dir Direction, proto string) string { + return fmt.Sprintf("[%-3s %4s]", dir, proto) +} + +// WritePacket formats and writes a single packet line. +func (tw *TextWriter) WritePacket(ts time.Time, data []byte, dir Direction) error { + ts = ts.Local() + info, ok := parsePacketInfo(data) + if !ok { + _, err := fmt.Fprintf(tw.w, "%s [%-3s ?] ??? 
len=%d\n", + ts.Format("15:04:05.000000"), dir, len(data)) + return err + } + + timeStr := ts.Format("15:04:05.000000") + + var err error + switch info.proto { + case protoTCP: + err = tw.writeTCP(timeStr, dir, &info, data) + case protoUDP: + err = tw.writeUDP(timeStr, dir, &info, data) + case protoICMP: + err = tw.writeICMPv4(timeStr, dir, &info, data) + case protoICMPv6: + err = tw.writeICMPv6(timeStr, dir, &info, data) + default: + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err = fmt.Fprintf(tw.w, "%s %s %s > %s length %d%s\n", + timeStr, tag(dir, fmt.Sprintf("P%d", info.proto)), + info.srcIP, info.dstIP, len(data)-info.hdrLen, verbose) + } + return err +} + +func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, data []byte) error { + tcp := &layers.TCP{} + if err := tcp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "TCP", info, data) + } + + flags := tcpFlagsStr(tcp) + plen := len(tcp.Payload) + + // Protocol annotation + var annotation string + if plen > 0 { + annotation = annotatePayload(tcp.Payload) + } + + if !tw.verbose { + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s [%s] length %d%s\n", + timeStr, tag(dir, "TCP"), + net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))), + flags, plen, annotation) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(tcp.Payload) + } + return nil + } + + relSeq, relAck := tw.relativeSeqAck(info, tcp.Seq, tcp.Ack) + + var seqStr string + if plen > 0 { + seqStr = fmt.Sprintf(", seq %d:%d", relSeq, relSeq+uint32(plen)) + } else { + seqStr = fmt.Sprintf(", seq %d", relSeq) + } + + var ackStr string + if tcp.ACK { + ackStr = fmt.Sprintf(", ack %d", relAck) + } + + var opts string + if s := formatTCPOptions(tcp.Options); s != "" { + opts = ", options [" + s + "]" + } + 
verbose := tw.verboseIP(data, info.family)

	_, err := fmt.Fprintf(tw.w, "%s %s %s > %s [%s]%s%s, win %d%s, length %d%s%s\n",
		timeStr, tag(dir, "TCP"),
		net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))),
		flags, seqStr, ackStr, tcp.Window, opts, plen, annotation, verbose)
	if err != nil {
		return err
	}
	if tw.ascii && plen > 0 {
		return tw.writeASCII(tcp.Payload)
	}
	return nil
}

// writeUDP formats one UDP packet line. When either port is a DNS port and
// the payload decodes as DNS, a decoded DNS summary replaces the generic
// length line entirely.
func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, data []byte) error {
	udp := &layers.UDP{}
	if err := udp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
		return tw.writeFallback(timeStr, dir, "UDP", info, data)
	}

	plen := len(udp.Payload)

	// DNS replaces the entire line format
	if plen > 0 && isDNSPort(info.srcPort, info.dstPort) {
		if s := formatDNSPayload(udp.Payload); s != "" {
			var verbose string
			if tw.verbose {
				verbose = tw.verboseIP(data, info.family)
			}
			_, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s%s\n",
				timeStr, tag(dir, "UDP"),
				net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))),
				s, verbose)
			return err
		}
	}

	var verbose string
	if tw.verbose {
		verbose = tw.verboseIP(data, info.family)
	}
	_, err := fmt.Fprintf(tw.w, "%s %s %s > %s length %d%s\n",
		timeStr, tag(dir, "UDP"),
		net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))),
		plen, verbose)
	if err != nil {
		return err
	}
	if tw.ascii && plen > 0 {
		return tw.writeASCII(udp.Payload)
	}
	return nil
}

// writeICMPv4 formats one ICMPv4 packet line; echo request/reply lines also
// include the id and sequence number.
func (tw *TextWriter) writeICMPv4(timeStr string, dir Direction, info *packetInfo, data []byte) error {
	icmp := &layers.ICMPv4{}
	if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
		return tw.writeFallback(timeStr, dir, "ICMP", info, data)
	}

	var detail string
	if icmp.TypeCode.Type() == layers.ICMPv4TypeEchoRequest || icmp.TypeCode.Type() == layers.ICMPv4TypeEchoReply {
		detail = fmt.Sprintf("%s, id %d, seq %d", icmp.TypeCode.String(), icmp.Id, icmp.Seq)
	} else {
		detail = icmp.TypeCode.String()
	}

	var verbose string
	if tw.verbose {
		verbose = tw.verboseIP(data, info.family)
	}
	_, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n",
		timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, detail, len(data)-info.hdrLen, verbose)
	return err
}

// writeICMPv6 formats one ICMPv6 packet line. The tag reuses the "ICMP"
// label (not "ICMP6") so the fixed-width prefix stays aligned with the
// other protocols.
func (tw *TextWriter) writeICMPv6(timeStr string, dir Direction, info *packetInfo, data []byte) error {
	icmp := &layers.ICMPv6{}
	if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
		return tw.writeFallback(timeStr, dir, "ICMP", info, data)
	}

	var verbose string
	if tw.verbose {
		verbose = tw.verboseIP(data, info.family)
	}
	_, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n",
		timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, icmp.TypeCode.String(), len(data)-info.hdrLen, verbose)
	return err
}

// writeFallback emits a minimal address/length line when transport-layer
// decoding fails.
func (tw *TextWriter) writeFallback(timeStr string, dir Direction, proto string, info *packetInfo, data []byte) error {
	_, err := fmt.Fprintf(tw.w, "%s %s %s > %s length %d\n",
		timeStr, tag(dir, proto),
		net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))),
		len(data)-info.hdrLen)
	return err
}

// verboseIP renders the ", ttl N, id N, iplen N" suffix used in verbose mode.
// For IPv6 the "ttl" value is the hop limit and "id" is always 0 (see ipTTL
// and ipID).
func (tw *TextWriter) verboseIP(data []byte, family uint8) string {
	return fmt.Sprintf(", ttl %d, id %d, iplen %d",
		ipTTL(data, family), ipID(data, family), len(data))
}

// relativeSeqAck returns seq/ack relative to the first seen value per direction.
+func (tw *TextWriter) relativeSeqAck(info *packetInfo, seq, ack uint32) (relSeq, relAck uint32) { + fwd := dirKey{ + src: netip.AddrPortFrom(info.srcIP, info.srcPort), + dst: netip.AddrPortFrom(info.dstIP, info.dstPort), + } + rev := dirKey{ + src: netip.AddrPortFrom(info.dstIP, info.dstPort), + dst: netip.AddrPortFrom(info.srcIP, info.srcPort), + } + + if isn, ok := tw.flows[fwd]; ok { + relSeq = seq - isn + } else { + tw.flows[fwd] = seq + } + + if isn, ok := tw.flows[rev]; ok { + relAck = ack - isn + } else { + relAck = ack + } + + return relSeq, relAck +} + +// writeASCII prints payload bytes as printable ASCII. +func (tw *TextWriter) writeASCII(payload []byte) error { + if len(payload) == 0 { + return nil + } + buf := make([]byte, len(payload)) + for i, b := range payload { + switch { + case b >= 0x20 && b < 0x7f: + buf[i] = b + case b == '\n' || b == '\r' || b == '\t': + buf[i] = b + default: + buf[i] = '.' + } + } + _, err := fmt.Fprintf(tw.w, "%s\n", buf) + return err +} + +// --- TCP helpers --- + +func ipTTL(data []byte, family uint8) uint8 { + if family == 4 && len(data) > 8 { + return data[8] + } + if family == 6 && len(data) > 7 { + return data[7] + } + return 0 +} + +func ipID(data []byte, family uint8) uint16 { + if family == 4 && len(data) >= 6 { + return binary.BigEndian.Uint16(data[4:6]) + } + return 0 +} + +func tcpFlagsStr(tcp *layers.TCP) string { + var buf [6]byte + n := 0 + if tcp.SYN { + buf[n] = 'S' + n++ + } + if tcp.FIN { + buf[n] = 'F' + n++ + } + if tcp.RST { + buf[n] = 'R' + n++ + } + if tcp.PSH { + buf[n] = 'P' + n++ + } + if tcp.ACK { + buf[n] = '.' 
+ n++ + } + if tcp.URG { + buf[n] = 'U' + n++ + } + if n == 0 { + return "none" + } + return string(buf[:n]) +} + +func formatTCPOptions(opts []layers.TCPOption) string { + var parts []string + for _, opt := range opts { + switch opt.OptionType { + case layers.TCPOptionKindEndList: + return strings.Join(parts, ",") + case layers.TCPOptionKindNop: + parts = append(parts, "nop") + case layers.TCPOptionKindMSS: + if len(opt.OptionData) == 2 { + parts = append(parts, fmt.Sprintf("mss %d", binary.BigEndian.Uint16(opt.OptionData))) + } + case layers.TCPOptionKindWindowScale: + if len(opt.OptionData) == 1 { + parts = append(parts, fmt.Sprintf("wscale %d", opt.OptionData[0])) + } + case layers.TCPOptionKindSACKPermitted: + parts = append(parts, "sackOK") + case layers.TCPOptionKindSACK: + blocks := len(opt.OptionData) / 8 + parts = append(parts, fmt.Sprintf("sack %d", blocks)) + case layers.TCPOptionKindTimestamps: + if len(opt.OptionData) == 8 { + tsval := binary.BigEndian.Uint32(opt.OptionData[0:4]) + tsecr := binary.BigEndian.Uint32(opt.OptionData[4:8]) + parts = append(parts, fmt.Sprintf("TS val %d ecr %d", tsval, tsecr)) + } + } + } + return strings.Join(parts, ",") +} + +// --- Protocol annotation --- + +// annotatePayload returns a protocol annotation string for known application protocols. 
+func annotatePayload(payload []byte) string { + if len(payload) < 4 { + return "" + } + + s := string(payload) + + // SSH banner: "SSH-2.0-OpenSSH_9.6\r\n" + if strings.HasPrefix(s, "SSH-") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 256 { + return ": " + s[:end] + } + } + + // TLS records + if ann := annotateTLS(payload); ann != "" { + return ": " + ann + } + + // HTTP request or response + for _, method := range [...]string{"GET ", "POST ", "PUT ", "DELETE ", "HEAD ", "PATCH ", "OPTIONS ", "CONNECT "} { + if strings.HasPrefix(s, method) { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + } + if strings.HasPrefix(s, "HTTP/") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + + return "" +} + +// annotateTLS returns a description for TLS handshake and alert records. +func annotateTLS(data []byte) string { + if len(data) < 6 { + return "" + } + + switch data[0] { + case 0x16: + return annotateTLSHandshake(data) + case 0x15: + return annotateTLSAlert(data) + } + return "" +} + +func annotateTLSHandshake(data []byte) string { + if len(data) < 10 { + return "" + } + switch data[5] { + case 0x01: + if sni := extractSNI(data); sni != "" { + return "TLS ClientHello SNI=" + sni + } + return "TLS ClientHello" + case 0x02: + return "TLS ServerHello" + } + return "" +} + +func annotateTLSAlert(data []byte) string { + if len(data) < 7 { + return "" + } + severity := "warning" + if data[5] == 2 { + severity = "fatal" + } + return fmt.Sprintf("TLS Alert %s %s", severity, tlsAlertDesc(data[6])) +} + +func tlsAlertDesc(code byte) string { + switch code { + case 0: + return "close_notify" + case 10: + return "unexpected_message" + case 40: + return "handshake_failure" + case 42: + return "bad_certificate" + case 43: + return "unsupported_certificate" + case 44: + return "certificate_revoked" + case 45: + return "certificate_expired" + case 48: + return "unknown_ca" + case 
49: + return "access_denied" + case 50: + return "decode_error" + case 70: + return "protocol_version" + case 80: + return "internal_error" + case 86: + return "inappropriate_fallback" + case 90: + return "user_canceled" + case 112: + return "unrecognized_name" + default: + return fmt.Sprintf("alert(%d)", code) + } +} + +// extractSNI parses a TLS ClientHello and returns the SNI server name. +func extractSNI(data []byte) string { + if len(data) < 6 || data[0] != 0x16 { + return "" + } + recordLen := int(binary.BigEndian.Uint16(data[3:5])) + handshake := data[5:] + if len(handshake) > recordLen { + handshake = handshake[:recordLen] + } + + if len(handshake) < 4 || handshake[0] != 0x01 { + return "" + } + hsLen := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3]) + body := handshake[4:] + if len(body) > hsLen { + body = body[:hsLen] + } + + extPos := clientHelloExtensionsOffset(body) + if extPos < 0 { + return "" + } + return findSNIExtension(body, extPos) +} + +// clientHelloExtensionsOffset returns the byte offset where extensions begin +// within the ClientHello body, or -1 if the body is too short. 
+func clientHelloExtensionsOffset(body []byte) int { + if len(body) < 38 { + return -1 + } + pos := 34 + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // session ID + + if pos+2 > len(body) { + return -1 + } + pos += 2 + int(binary.BigEndian.Uint16(body[pos:pos+2])) // cipher suites + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // compression methods + + if pos+2 > len(body) { + return -1 + } + return pos +} + +func findSNIExtension(body []byte, pos int) string { + extLen := int(binary.BigEndian.Uint16(body[pos : pos+2])) + pos += 2 + extEnd := pos + extLen + if extEnd > len(body) { + extEnd = len(body) + } + + for pos+4 <= extEnd { + extType := binary.BigEndian.Uint16(body[pos : pos+2]) + eLen := int(binary.BigEndian.Uint16(body[pos+2 : pos+4])) + pos += 4 + if pos+eLen > extEnd { + break + } + if extType == 0 && eLen >= 5 { + nameLen := int(binary.BigEndian.Uint16(body[pos+3 : pos+5])) + if pos+5+nameLen <= extEnd { + return string(body[pos+5 : pos+5+nameLen]) + } + } + pos += eLen + } + return "" +} + +func isDNSPort(src, dst uint16) bool { + return src == 53 || dst == 53 || src == 5353 || dst == 5353 +} + +// formatDNSPayload parses DNS and returns a tcpdump-style summary. +func formatDNSPayload(payload []byte) string { + d := &layers.DNS{} + if err := d.DecodeFromBytes(payload, gopacket.NilDecodeFeedback); err != nil { + return "" + } + + rd := "" + if d.RD { + rd = "+" + } + + if !d.QR { + return formatDNSQuery(d, rd, len(payload)) + } + return formatDNSResponse(d, rd, len(payload)) +} + +func formatDNSQuery(d *layers.DNS, rd string, plen int) string { + if len(d.Questions) == 0 { + return fmt.Sprintf("%04x%s (%d)", d.ID, rd, plen) + } + q := d.Questions[0] + return fmt.Sprintf("%04x%s %s? %s. 
(%d)", d.ID, rd, q.Type, q.Name, plen) +} + +func formatDNSResponse(d *layers.DNS, rd string, plen int) string { + anCount := d.ANCount + nsCount := d.NSCount + arCount := d.ARCount + + if d.ResponseCode != layers.DNSResponseCodeNoErr { + return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen) + } + + if anCount > 0 && len(d.Answers) > 0 { + rr := d.Answers[0] + if rdata := shortRData(&rr); rdata != "" { + return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen) + } + } + + return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen) +} + +func shortRData(rr *layers.DNSResourceRecord) string { + switch rr.Type { + case layers.DNSTypeA, layers.DNSTypeAAAA: + if rr.IP != nil { + return rr.IP.String() + } + case layers.DNSTypeCNAME: + if len(rr.CNAME) > 0 { + return string(rr.CNAME) + "." + } + case layers.DNSTypePTR: + if len(rr.PTR) > 0 { + return string(rr.PTR) + "." + } + case layers.DNSTypeNS: + if len(rr.NS) > 0 { + return string(rr.NS) + "." + } + case layers.DNSTypeMX: + return fmt.Sprintf("%d %s.", rr.MX.Preference, rr.MX.Name) + case layers.DNSTypeTXT: + if len(rr.TXTs) > 0 { + return fmt.Sprintf("%q", string(rr.TXTs[0])) + } + case layers.DNSTypeSRV: + return fmt.Sprintf("%d %d %d %s.", rr.SRV.Priority, rr.SRV.Weight, rr.SRV.Port, rr.SRV.Name) + } + return "" +}