Compare commits

..

2 Commits

Author SHA1 Message Date
mlsmaycon
bfeb60fbb5 Create a system proxy change after receiving a network map
This is experimental and needs more test.

the purpose of this change is to validate that a TLS connection stuck using old routes because of keepalive settings on the remote webserver are reset once netbird receives a network map
2026-02-01 10:23:25 +01:00
mlsmaycon
ea41cf2d2c Create a system proxy change after receiving a network map
This is experimental and needs more test.

the purpose of this change is to validate that a TLS connection stuck using old routes because of keepalive settings on the remote webserver are reset once netbird receives a network map
2026-02-01 10:21:51 +01:00
986 changed files with 14760 additions and 134664 deletions

View File

@@ -1,6 +0,0 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

View File

@@ -1,130 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Ideas & Feature Requests
Use this category for feature requests, enhancements, integrations, and product ideas.
NetBird uses community traction in discussions — upvotes, replies, affected users, and use-case detail — as an input when deciding what should become a maintainer-curated issue or roadmap item. A clear problem statement is more useful than a solution-only request.
Please search first and add your use case to an existing discussion when one already exists.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues for similar requests.
required: true
- label: I checked the documentation to confirm this is not already supported.
required: true
- label: This is a product idea or enhancement request, not a support question.
required: true
- label: I removed or anonymized sensitive details from examples and screenshots.
required: true
- type: dropdown
id: area
attributes:
label: Product area
description: Select every area this request touches.
multiple: true
options:
- Client / Agent
- CLI
- Desktop UI
- Mobile app
- Dashboard / Admin UI
- Management service / API
- Signal service
- Relay
- DNS
- Routes / Exit nodes
- NetBird SSH
- Access control policies
- Posture checks
- Identity provider / SSO
- Self-hosting / Deployment
- Kubernetes / Operator
- Terraform / Automation
- Documentation
- Other / not sure
validations:
required: true
- type: textarea
id: problem
attributes:
label: Problem or use case
description: What are you trying to accomplish, and what is difficult or impossible today?
placeholder: |
As a ...
I want to ...
Because ...
validations:
required: true
- type: textarea
id: proposal
attributes:
label: Proposed solution
description: Describe the behavior, workflow, API, UI, or integration you would like to see.
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Alternatives or workarounds considered
description: What have you tried today? Why is the current workaround not enough?
- type: textarea
id: impact
attributes:
label: Community impact and priority
description: Help us understand who benefits and how urgent this is.
placeholder: |
- Number of users/teams/peers affected:
- Deployment type: Cloud / self-hosted / both
- Frequency: daily / weekly / occasional
- Blocking production adoption? yes/no
- Related comments, discussions, or customer requests:
validations:
required: true
- type: textarea
id: examples
attributes:
label: Examples from other tools or products
description: If another tool solves this well, link or describe the behavior.
- type: textarea
id: security
attributes:
label: Security, privacy, and compatibility considerations
description: Note any access-control, audit, data retention, network, platform, or backward-compatibility concerns.
- type: textarea
id: implementation
attributes:
label: Implementation ideas
description: Optional. If you are familiar with the codebase or API, share possible implementation notes.
- type: dropdown
id: contribution
attributes:
label: Are you willing to help?
options:
- Yes, I can submit a PR if the approach is accepted.
- Yes, I can test or validate a proposed implementation.
- Yes, I can provide more use-case details.
- Not at this time.
validations:
required: true
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add screenshots, diagrams, links, or anything else that helps explain the request.

View File

@@ -1,237 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Issue Triage
Use this category for reproducible bugs and regressions in NetBird.
The more context you include, the faster we can validate and act on your report. If you're not sure whether something is a bug, **Q&A / Support** is a good starting point — we can always move the conversation here once we've confirmed it's a product issue.
Intermittent issues are useful too. Include the trigger, frequency, timing, and any logs or debug evidence you have, and we'll work from there.
Please don't include secrets, tokens, private keys, internal hostnames, or public IPs. Security vulnerabilities should be reported through the repository security policy rather than a public discussion.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues, including closed ones, and checked the relevant docs.
required: true
- label: I believe this is a product bug rather than a configuration or setup question.
required: true
- label: I can reproduce this issue, or for intermittent issues I've included trigger, frequency, and timing details below.
required: true
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
required: true
- type: dropdown
id: area
attributes:
label: Affected area
description: Select every area this report touches.
multiple: true
options:
- Client / Agent
- Reverse Proxy
- CLI
- Desktop UI
- Mobile app
- Peer connectivity
- DNS
- Routes / Exit nodes
- NetBird SSH
- Relay / Signal / NAT traversal
- Login / Authentication / IdP
- Dashboard / Admin UI
- Management service / API
- Access control policies / Posture checks
- Self-hosting / Deployment
- Kubernetes / Operator
- Documentation
- Other / not sure
validations:
required: true
- type: dropdown
id: deployment
attributes:
label: Deployment type
options:
- NetBird Cloud
- Self-hosted - quickstart script
- Self-hosted - advanced/custom deployment
- Local development build
- Not sure / environment I do not fully control
validations:
required: true
- type: dropdown
id: platform
attributes:
label: Operating system or environment
description: Select every environment involved in the reproduction.
multiple: true
options:
- Linux
- macOS
- Windows
- Android
- iOS
- FreeBSD
- OpenWRT
- Docker
- Kubernetes
- Synology
- Browser
- Other / not sure
validations:
required: true
- type: textarea
id: version
attributes:
label: NetBird version and upgrade status
description: Run `netbird version` where applicable. For self-hosted deployments, include management, signal, relay, and dashboard versions if available. If you cannot test on a current/supported version, explain why.
placeholder: |
Example:
- Client: 0.30.2
- Management: 0.30.2
- Signal: 0.30.2
- Relay: 0.30.2
- Dashboard: 0.30.2
- Upgrade status: reproduced on current version / cannot upgrade because ...
validations:
required: true
- type: dropdown
id: regression
attributes:
label: Did this work before?
options:
- Yes, this worked before
- No, this never worked
- Not sure
validations:
required: true
- type: textarea
id: regression-details
attributes:
label: Regression details
description: If this worked before, include the last known working version, first known broken version, and any recent upgrade, configuration, network, or IdP changes.
placeholder: |
- Last known working version:
- First known broken version:
- Recent changes:
- type: textarea
id: summary
attributes:
label: Summary
description: Briefly describe the reproducible bug.
placeholder: What is broken?
validations:
required: true
- type: textarea
id: current-behavior
attributes:
label: Current behavior
description: What happens now? Include exact errors, timeouts, UI messages, or failed commands when possible.
validations:
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected behavior
description: What did you expect to happen instead?
validations:
required: true
- type: textarea
id: reproduction
attributes:
label: Steps to reproduce
description: Provide the smallest set of steps that reliably reproduces the bug. If the issue is intermittent, include the trigger, frequency, timing, and relevant timestamps.
placeholder: |
1. Configure ...
2. Run ...
3. Observe ...
For intermittent issues:
- Trigger:
- Frequency:
- Timing/timestamps:
validations:
required: true
- type: textarea
id: environment
attributes:
label: Environment and topology
description: Include the relevant topology and software involved in the reproduction. For UI/docs-only reports, write `N/A` if this does not apply. Use `None`, `Unknown`, or `N/A` where appropriate.
placeholder: |
- Peer A:
- Peer B:
- Same LAN or different networks:
- NAT/CGNAT/corporate firewall/mobile network:
- Other VPN software:
- Firewall, DNS, or endpoint security software:
- Routes, DNS, policies, posture checks, or SSH rules involved:
- IdP, reverse proxy, or browser involved:
validations:
required: true
- type: textarea
id: self-hosted-details
attributes:
label: Self-hosted details, if available
description: Optional. If you use self-hosting and have access to these details, include them. If you do not administer the environment, provide what you know and say what you cannot access.
placeholder: |
- Deployment method: quickstart / Docker Compose / Helm / operator / custom
- Management/signal/relay/dashboard versions:
- Reverse proxy:
- IdP/provider:
- STUN/TURN/coturn/relay details:
- Relevant component logs:
- type: textarea
id: logs
attributes:
label: Logs, status output, or debug evidence
description: |
For client, connectivity, DNS, route, relay/signal, or self-hosted reports, logs are essential — please include anonymized output from `netbird status -dA`, or a debug bundle via `netbird debug for 1m -AS -U`. Debug bundles are automatically deleted after 30 days.
For UI, dashboard, or documentation reports, leave the pre-filled `N/A`.
value: "N/A"
render: shell
validations:
required: true
- type: textarea
id: related-reports
attributes:
label: Related issues or discussions
description: Optional. Link similar reports you found while searching, if any.
placeholder: |
- Related issue/discussion:
- Why this may be the same or different:
- type: textarea
id: impact
attributes:
label: Impact
description: Optional. Help us understand priority. How many users, peers, environments, or workflows are affected? Is there a workaround?
placeholder: |
- Affected users/peers:
- Business or production impact:
- Workaround available:
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add links to related discussions, issues, docs, screenshots, recordings, or anything else that may help validation.

View File

@@ -1,146 +0,0 @@
body:
- type: markdown
attributes:
value: |
## Q&A / Support
Use this category for questions about configuration, setup, self-hosted deployments, troubleshooting, and general NetBird usage.
This is community support and does not provide an SLA. For NetBird Cloud support, use the official support channel linked from the issue creation page. Please do not post secrets, tokens, private keys, internal hostnames, or public IPs unless you intentionally want them public.
If your question turns into a reproducible product defect, DevRel or a maintainer may ask you to open or move the conversation to Issue Triage.
- type: checkboxes
id: preflight
attributes:
label: Before posting
options:
- label: I searched existing discussions and issues for similar questions.
required: true
- label: I reviewed the relevant NetBird documentation or troubleshooting guide.
required: true
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
required: true
- type: dropdown
id: topic
attributes:
label: Topic
multiple: true
options:
- Getting started
- Self-hosting
- Client / Agent
- CLI
- Desktop UI
- Mobile app
- Dashboard / Admin UI
- DNS
- Routes / Exit nodes
- NetBird SSH
- Relay
- Access control policies
- Posture checks
- Identity provider / SSO
- API
- Kubernetes / Operator
- Terraform / Automation
- Documentation
- Other / not sure
validations:
required: true
- type: dropdown
id: deployment
attributes:
label: Deployment type
options:
- NetBird Cloud
- Self-hosted - quickstart script
- Self-hosted - advanced/custom deployment
- Local development build
- Not sure
validations:
required: true
- type: dropdown
id: platform
attributes:
label: Operating system or environment
multiple: true
options:
- Linux
- macOS
- Windows
- Android
- iOS
- FreeBSD
- OpenWRT
- Docker
- Kubernetes
- Synology
- Browser
- Other / not sure
validations:
required: true
- type: input
id: version
attributes:
label: NetBird version
description: Run `netbird version` where applicable. For self-hosted deployments, include component versions if relevant.
placeholder: "Example: client 0.30.2, management 0.30.2"
- type: textarea
id: question
attributes:
label: Question
description: What are you trying to understand or accomplish?
placeholder: Describe your question clearly.
validations:
required: true
- type: textarea
id: goal
attributes:
label: Desired outcome
description: What would a successful answer help you do?
placeholder: |
I want to configure ...
I expected ...
I need help deciding ...
- type: textarea
id: attempted
attributes:
label: What have you tried?
description: Include commands, documentation links, configuration attempts, or troubleshooting steps already tried.
placeholder: |
- Read ...
- Ran ...
- Changed ...
- Observed ...
- type: textarea
id: environment
attributes:
label: Relevant environment details
description: Include redacted topology, IdP/provider, reverse proxy, firewall, DNS, route, policy, or self-hosted setup details that may affect the answer.
placeholder: |
- Deployment:
- Components involved:
- Network/topology:
- Related config:
- type: textarea
id: logs
attributes:
label: Logs or output
description: Optional. Include anonymized logs, command output, screenshots, or `netbird status -dA` if relevant.
render: shell
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add links, diagrams, screenshots, or other details that may help the community answer.

View File

@@ -0,0 +1,71 @@
---
name: Bug/Issue report
about: Create a report to help us improve
title: ''
labels: ['triage-needed']
assignees: ''
---
**Describe the problem**
A clear and concise description of what the problem is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Are you using NetBird Cloud?**
Please specify whether you use NetBird Cloud or self-host NetBird's control plane.
**NetBird version**
`netbird version`
**Is any other VPN software installed?**
If yes, which one?
**Debug output**
To help us resolve the problem, please attach the following anonymized status output
netbird status -dA
Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -1,26 +0,0 @@
blank_issues_enabled: false
contact_links:
- name: Start an Issue Triage discussion
url: https://github.com/netbirdio/netbird/discussions/new?category=issue-triage
about: Report a bug, regression, or unexpected behavior so DevRel can validate it before it becomes an issue.
- name: Propose an idea or feature request
url: https://github.com/netbirdio/netbird/discussions/new?category=ideas-feature-requests
about: Share feature requests, enhancements, and integration ideas for community feedback and prioritization.
- name: Ask a Q&A / Support question
url: https://github.com/netbirdio/netbird/discussions/new?category=q-a-support
about: Get help with setup, configuration, self-hosting, troubleshooting, and general usage.
- name: Security vulnerability disclosure
url: https://github.com/netbirdio/netbird/security/policy
about: Please do not report security vulnerabilities in public issues or discussions.
- name: Community Support Forum
url: https://forum.netbird.io/
about: Community support forum.
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact NetBird for Cloud support.
- name: Client / Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See the client troubleshooting guide for common connectivity issues.
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See the self-host troubleshooting guide for common deployment issues.

View File

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ['feature-request']
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

View File

@@ -1,128 +0,0 @@
name: Validated issue
description: Maintainer/DevRel only. Create an issue after a discussion has been validated or for internally validated work.
title: "[Validated]: "
body:
- type: markdown
attributes:
value: |
## Discussion-first issue policy
Issues are maintainer-curated work items. Community reports and feature requests should start in [Discussions](https://github.com/netbirdio/netbird/discussions) so DevRel can validate, reproduce, and route them before engineering time is committed.
Use this form when:
- A discussion has been validated and should become actionable work.
- A maintainer is opening internally validated work that can bypass the discussion-first flow.
Issues opened without a relevant validated discussion or maintainer context may be closed and redirected to Discussions.
- type: checkboxes
id: validation-checks
attributes:
label: Validation checklist
options:
- label: This issue is linked to a validated discussion, or it is being opened directly by a maintainer.
required: true
- label: The report has enough context for engineering to act on it without re-triaging from scratch.
required: true
- label: Sensitive data, secrets, private keys, internal hostnames, and public IPs have been removed or intentionally disclosed.
required: true
- type: dropdown
id: issue-type
attributes:
label: Issue type
options:
- Bug / Regression
- Feature / Enhancement
- Documentation
- Maintenance / Refactor
- Cross-repository coordination
- Other
validations:
required: true
- type: input
id: source-discussion
attributes:
label: Source discussion
description: Link the GitHub Discussion that was validated. Maintainers bypassing the flow can write "Maintainer-created" and explain why below.
placeholder: https://github.com/netbirdio/netbird/discussions/1234
validations:
required: true
- type: input
id: validation-owner
attributes:
label: Validation owner
description: GitHub handle of the DevRel team member or maintainer who validated this work.
placeholder: "@username"
validations:
required: true
- type: dropdown
id: target-repository
attributes:
label: Target repository
description: Where should the implementation work happen?
options:
- netbirdio/netbird
- netbirdio/dashboard
- netbirdio/kubernetes-operator
- netbirdio/docs
- Multiple repositories
- Unknown / needs routing
validations:
required: true
- type: textarea
id: summary
attributes:
label: Summary
description: Concise description of the validated work.
placeholder: What needs to be fixed, changed, documented, or built?
validations:
required: true
- type: textarea
id: evidence
attributes:
label: Validation evidence
description: For bugs, include reproduction status, affected versions, logs, and environment. For features, include community traction, affected users, and alignment notes.
placeholder: |
- Reproduced by:
- Affected versions / platforms:
- Community signal:
- Related logs or screenshots:
validations:
required: true
- type: textarea
id: scope
attributes:
label: Proposed scope
description: Describe what is in scope and, if helpful, what is explicitly out of scope.
placeholder: |
In scope:
- ...
Out of scope:
- ...
validations:
required: true
- type: textarea
id: acceptance-criteria
attributes:
label: Acceptance criteria
description: What must be true for this issue to be closed?
placeholder: |
- [ ] ...
- [ ] ...
validations:
required: true
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Links to related PRs, docs, issues in other repositories, roadmap items, or implementation notes.

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies - name: Check for problematic license dependencies
run: | run: |
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..." echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo "" echo ""
# Find all directories except the problematic ones and system dirs # Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do while IFS= read -r dir; do
echo "=== Checking $dir ===" echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files # Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true) RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:" echo "❌ Found problematic dependencies:"
echo "$RESULTS" echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else else
echo "✓ No problematic dependencies found" echo "✓ No problematic dependencies found"
fi fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort) done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo "" echo ""
if [ $FOUND_ISSUES -eq 1 ]; then if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages" echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1 exit 1
else else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath") IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay # Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1) BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER" echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -46,5 +46,6 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/... time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/... time go test -timeout 1m -failfast ./version/...

View File

@@ -97,16 +97,6 @@ jobs:
working-directory: relay working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
- name: Build combined
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 go build .
- name: Build combined 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
test: test:
name: "Client / Unit" name: "Client / Unit"
needs: [build-cache] needs: [build-cache]
@@ -154,7 +144,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker: test_client_on_docker:
name: "Client (Docker) / Unit" name: "Client (Docker) / Unit"
@@ -214,7 +204,7 @@ jobs:
sh -c ' \ sh -c ' \
apk update; apk add --no-cache \ apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server) go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
' '
test_relay: test_relay:
@@ -271,53 +261,6 @@ jobs:
-exec 'sudo' \ -exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/... -timeout 10m -p 1 ./relay/... ./shared/relay/...
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal: test_signal:
name: "Signal / Unit" name: "Signal / Unit"
needs: [build-cache] needs: [build-cache]
@@ -409,19 +352,12 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Login to Docker hub - name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v3 uses: docker/login-action@v1
with: with:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image - name: download mysql image
if: matrix.store == 'mysql' if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8 run: docker pull mlsmaycon/warmed-mysql:8
@@ -504,18 +440,15 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Login to Docker hub - name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v3 uses: docker/login-action@v1
with: with:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user - name: download mysql image
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref if: matrix.store == 'mysql'
env: run: docker pull mlsmaycon/warmed-mysql:8
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: Test - name: Test
run: | run: |
@@ -596,18 +529,15 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Login to Docker hub - name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v3 uses: docker/login-action@v1
with: with:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user - name: download mysql image
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref if: matrix.store == 'mysql'
env: run: docker pull mlsmaycon/warmed-mysql:8
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: Test - name: Test
run: | run: |

View File

@@ -63,15 +63,10 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- name: Generate test script - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output - name: test output
if: ${{ always() }} if: ${{ always() }}
run: Get-Content test-out.txt run: Get-Content test-out.txt

View File

@@ -19,8 +19,8 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum,**/proxy/web/** skip: go.mod,go.sum
golangci: golangci:
strategy: strategy:
fail-fast: false fail-fast: false

View File

@@ -1,51 +0,0 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -1,62 +0,0 @@
name: Proto Version Check
on:
pull_request:
paths:
- "**/*.pb.go"
jobs:
check-proto-versions:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@v7
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number,
per_page: 100,
});
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(
`Proto version strings changed in generated files.\n` +
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
`Regenerate with the matching tool versions.\n\n` +
details
);
return;
}
console.log('No proto version string changes detected');

View File

@@ -9,8 +9,8 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.1.4" SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.14.3" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -114,13 +114,7 @@ jobs:
retention-days: 30 retention-days: 30
release: release:
runs-on: ubuntu-24.04-8-core runs-on: ubuntu-latest-m
outputs:
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env: env:
flags: "" flags: ""
steps: steps:
@@ -166,7 +160,7 @@ jobs:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry - name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository if: github.event_name != 'pull_request'
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
@@ -175,14 +169,6 @@ jobs:
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
@@ -190,7 +176,6 @@ jobs:
- name: Generate windows syso arm64 - name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
@@ -200,109 +185,25 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*amd64*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only)
id: tag_and_push_images
if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
run: |
set -euo pipefail
resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}"
else
echo "main sha-$(git rev-parse --short HEAD)"
fi
}
ghcr_package_url() {
local image="$1" package encoded_package
package="${image#ghcr.io/}"
package="${package#*/}"
package="${package%%:*}"
encoded_package="${package//\//%2F}"
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
}
image_refs=()
tag_and_push() {
local src="$1" img_name tag dst
img_name="${src%%:*}"
for tag in $(resolve_tags); do
dst="${img_name}:${tag}"
echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst"
docker push "$dst"
image_refs+=("$dst")
done
}
cat > /tmp/goreleaser-artifacts.json <<'JSON'
${{ steps.goreleaser.outputs.artifacts }}
JSON
mapfile -t src_images < <(
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
)
for src in "${src_images[@]}"; do
tag_and_push "$src"
done
{
echo "images_markdown<<EOF"
if [[ ${#image_refs[@]} -eq 0 ]]; then
echo "_No GHCR images were pushed._"
else
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
done
fi
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 7 retention-days: 7
- name: upload linux packages - name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 7 retention-days: 7
- name: upload windows packages - name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 7 retention-days: 7
- name: upload macos packages - name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
@@ -311,8 +212,6 @@ jobs:
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps: steps:
- name: Parse semver string - name: Parse semver string
id: semver_parser id: semver_parser
@@ -352,14 +251,6 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install LLVM-MinGW for ARM64 cross-compilation - name: Install LLVM-MinGW for ARM64 cross-compilation
run: | run: |
cd /tmp cd /tmp
@@ -384,26 +275,7 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release-ui name: release-ui
@@ -412,8 +284,6 @@ jobs:
release_ui_darwin: release_ui_darwin:
runs-on: macos-latest runs-on: macos-latest
outputs:
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
@@ -448,258 +318,15 @@ jobs:
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/
retention-days: 3 retention-days: 3
test_windows_installer:
name: "Windows Installer / Build Test"
runs-on: windows-2022
needs: [release, release_ui]
strategy:
fail-fast: false
matrix:
include:
- arch: amd64
wintun_arch: amd64
- arch: arm64
wintun_arch: arm64
defaults:
run:
shell: powershell
env:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@v4
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@v4
with:
name: release-ui
path: release-ui
- name: Stage binaries into dist
run: |
$workdir = "dist\${{ env.PackageWorkdir }}"
New-Item -ItemType Directory -Force -Path $workdir | Out-Null
$client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
$ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 }
if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 }
Write-Host "Client: $($client.FullName)"
Write-Host "UI: $($ui.FullName)"
tar -zvxf $client.FullName -C $workdir
tar -zvxf $ui.FullName -C $workdir
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompress wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
with:
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
- name: Move opengl32.dll into dist (amd64 only)
if: matrix.arch == 'amd64'
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: carlosperate/download-file-action@v2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
- name: Install WiX
run: |
dotnet tool install --global wix --version 6.0.2
wix extension add WixToolset.Util.wixext/6.0.2
- name: Build MSI installer
env:
NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}"
run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }}
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
netbird_installer_test_windows_${{ matrix.arch }}.exe
netbird_installer_test_windows_${{ matrix.arch }}.msi
retention-days: 3
comment_release_artifacts:
name: Comment release artifacts
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin]
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const marker = '<!-- netbird-release-artifacts -->';
const { owner, repo } = context.repo;
const issue_number = context.payload.pull_request.number;
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
const artifactCell = (url, result) => {
if (url) return `[Download](${url})`;
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
};
const artifacts = [
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
];
const artifactRows = artifacts
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
.join('\n');
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
const body = [
marker,
'## Release artifacts',
'',
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
'',
'| Artifact | Link |',
'| --- | --- |',
artifactRows,
'',
'### GHCR images (amd64)',
ghcrImages,
'',
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
].join('\n');
const comments = await github.paginate(github.rest.issues.listComments, {
owner,
repo,
issue_number,
per_page: 100,
});
const previous = comments.find(comment =>
comment.user?.type === 'Bot' && comment.body?.includes(marker)
);
if (previous) {
await github.rest.issues.updateComment({
owner,
repo,
comment_id: previous.id,
body,
});
core.info(`Updated release artifacts comment ${previous.id}`);
} else {
const { data } = await github.rest.issues.createComment({
owner,
repo,
issue_number,
body,
});
core.info(`Created release artifacts comment ${data.id}`);
}
trigger_signer: trigger_signer:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin, test_windows_installer] needs: [release, release_ui, release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:
- name: Trigger binaries sign pipelines - name: Trigger binaries sign pipelines

View File

@@ -9,8 +9,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true cancel-in-progress: true
# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
jobs: jobs:
trigger_sync_tag: trigger_sync_tag:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -22,30 +20,4 @@ jobs:
ref: main ref: main
repo: ${{ secrets.UPSTREAM_REPO }} repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }} token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_android_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/android-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_ios_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }' inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -61,8 +61,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 58720256 ]; then if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!" echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
exit 1 exit 1
fi fi

2
.gitignore vendored
View File

@@ -2,7 +2,6 @@
.run .run
*.iml *.iml
dist/ dist/
!proxy/web/dist/
bin/ bin/
.env .env
conf.json conf.json
@@ -33,4 +32,3 @@ infrastructure_files/setup-*.env
vendor/ vendor/
/netbird /netbird
client/netbird-electron/ client/netbird-electron/
management/server/types/testdata/

View File

@@ -58,11 +58,6 @@ linters:
govet: govet:
enable: enable:
- nilness - nilness
disable:
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
# directives but cannot perform the rewrite due to generic type
# parameter inference limitations in the Go inliner.
- inline
enable-all: false enable-all: false
revive: revive:
rules: rules:

View File

@@ -106,26 +106,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload - id: netbird-upload
dir: upload-server dir: upload-server
env: [CGO_ENABLED=0] env: [CGO_ENABLED=0]
@@ -140,40 +120,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-proxy
dir: proxy/cmd/proxy
env: [CGO_ENABLED=0]
binary: netbird-proxy
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries: universal_binaries:
- id: netbird - id: netbird
@@ -186,22 +132,18 @@ archives:
- netbird-wasm - netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}" name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client. description: Netbird client.
homepage: https://netbird.io/ homepage: https://netbird.io/
license: BSD-3-Clause id: netbird-deb
id: netbird_deb
bindir: /usr/bin bindir: /usr/bin
builds: builds:
- netbird - netbird
formats: formats:
- deb - deb
scripts: scripts:
postinstall: "release_files/post_install.sh" postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh" preremove: "release_files/pre_remove.sh"
@@ -209,19 +151,16 @@ nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client. description: Netbird client.
homepage: https://netbird.io/ homepage: https://netbird.io/
license: BSD-3-Clause id: netbird-rpm
id: netbird_rpm
bindir: /usr/bin bindir: /usr/bin
builds: builds:
- netbird - netbird
formats: formats:
- rpm - rpm
scripts: scripts:
postinstall: "release_files/post_install.sh" postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh" preremove: "release_files/pre_remove.sh"
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers: dockers:
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
@@ -581,104 +520,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests: docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }} - name_template: netbirdio/netbird:{{ .Version }}
image_templates: image_templates:
@@ -757,18 +598,6 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm - netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64 - netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }} - name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates: image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8 - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -846,43 +675,6 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8 - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm - ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64 - ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
brews: brews:
- ids: - ids:
- default - default
@@ -903,7 +695,7 @@ brews:
uploads: uploads:
- name: debian - name: debian
ids: ids:
- netbird_deb - netbird-deb
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com username: dev@wiretrustee.com
@@ -911,7 +703,7 @@ uploads:
- name: yum - name: yum
ids: ids:
- netbird_rpm - netbird-rpm
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com username: dev@wiretrustee.com

View File

@@ -61,7 +61,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client UI. description: Netbird client UI.
homepage: https://netbird.io/ homepage: https://netbird.io/
id: netbird_ui_deb id: netbird-ui-deb
package_name: netbird-ui package_name: netbird-ui
builds: builds:
- netbird-ui - netbird-ui
@@ -80,7 +80,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client UI. description: Netbird client UI.
homepage: https://netbird.io/ homepage: https://netbird.io/
id: netbird_ui_rpm id: netbird-ui-rpm
package_name: netbird-ui package_name: netbird-ui
builds: builds:
- netbird-ui - netbird-ui
@@ -95,14 +95,11 @@ nfpms:
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
uploads: uploads:
- name: debian - name: debian
ids: ids:
- netbird_ui_deb - netbird-ui-deb
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com username: dev@wiretrustee.com
@@ -110,7 +107,7 @@ uploads:
- name: yum - name: yum
ids: ids:
- netbird_ui_rpm - netbird-ui-rpm
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com username: dev@wiretrustee.com

View File

@@ -1,7 +1,7 @@
## Contributor License Agreement ## Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany, submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
under which NetBird may utilize software contributions provided by the Contributor for inclusion in under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance its software development projects. By submitting this Agreement, the Contributor confirms their acceptance

View File

@@ -1,4 +1,4 @@
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/. This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory. Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
BSD 3-Clause License BSD 3-Clause License

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT): $(GOLANGCI_LINT):
@echo "Installing golangci-lint..." @echo "Installing golangci-lint..."
@mkdir -p ./bin @mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push) # Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT) lint: $(GOLANGCI_LINT)

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-Host NetBird (Video) ### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features ### Key features
@@ -126,7 +126,6 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
### Community projects ### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer) - [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/) - [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development. **Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases). For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.3 FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \
@@ -17,8 +17,8 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \ NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,8 +23,8 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \ NB_DISABLE_DNS="true" \
NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -8,7 +8,6 @@ import (
"os" "os"
"slices" "slices"
"sync" "sync"
"time"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -16,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -28,7 +26,6 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -71,30 +68,7 @@ type Client struct {
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
c.setState(cfg, cacheDir, connectClient) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
c.setState(cfg, cacheDir, connectClient) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
} }
func (c *Client) RenewTun(fd int) error { func (c *Client) RenewTun(fd int) error {
cc := c.getConnectClient() if c.connectClient == nil {
if cc == nil {
return fmt.Errorf("engine not running") return fmt.Errorf("engine not running")
} }
e := cc.Engine() e := c.connectClient.Engine()
if e == nil { if e == nil {
return fmt.Errorf("engine not initialized") return fmt.Errorf("engine not initialized")
} }
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd) return e.RenewTun(fd)
} }
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level // SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() { func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
@@ -301,11 +203,10 @@ func (c *Client) PeersList() *PeerInfoArray {
peerInfos := make([]PeerInfo, len(fullStatus.Peers)) peerInfos := make([]PeerInfo, len(fullStatus.Peers))
for n, p := range fullStatus.Peers { for n, p := range fullStatus.Peers {
pi := PeerInfo{ pi := PeerInfo{
IP: p.IP, p.IP,
IPv6: p.IPv6, p.FQDN,
FQDN: p.FQDN, p.ConnStatus.String(),
ConnStatus: int(p.ConnStatus), PeerRoutes{routes: maps.Keys(p.GetRoutes())},
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
} }
peerInfos[n] = pi peerInfos[n] = pi
} }
@@ -313,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
} }
func (c *Client) Networks() *NetworkArray { func (c *Client) Networks() *NetworkArray {
cc := c.getConnectClient() if c.connectClient == nil {
if cc == nil {
log.Error("not connected") log.Error("not connected")
return nil return nil
} }
engine := cc.Engine() engine := c.connectClient.Engine()
if engine == nil { if engine == nil {
log.Error("could not get engine") log.Error("could not get engine")
return nil return nil
@@ -337,84 +237,43 @@ func (c *Client) Networks() *NetworkArray {
return nil return nil
} }
routesMap := routeManager.GetClientRoutesWithNetID()
v6Merged := route.V6ExitMergeSet(routesMap)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
networkArray := &NetworkArray{ networkArray := &NetworkArray{
items: make([]Network, 0), items: make([]Network, 0),
} }
for id, routes := range routesMap { resolvedDomains := c.recorder.GetResolvedDomainsStates()
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 { if len(routes) == 0 {
continue continue
} }
if _, skip := v6Merged[id]; skip {
continue r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
} }
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged) routePeer, err := c.recorder.GetPeer(routes[0].Peer)
if network == nil { if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue continue
} }
networkArray.Add(*network) network := Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
}
networkArray.Add(network)
} }
return networkArray return networkArray
} }
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.findBestRoutePeer(routes)
if err != nil {
log.Errorf("could not get peer info for route %s: %v", id, err)
return nil
}
network := &Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: selected,
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
}
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
network.Network = "0.0.0.0/0, ::/0"
}
return network
}
// findBestRoutePeer returns the peer actively routing traffic for the given
// HA route group. Falls back to the first connected peer, then the first peer.
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
netStr := routes[0].Network.String()
fullStatus := c.recorder.GetFullStatus()
for _, p := range fullStatus.Peers {
if _, ok := p.GetRoutes()[netStr]; ok {
return p, nil
}
}
for _, r := range routes {
p, err := c.recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if p.ConnStatus == peer.StatusConnected {
return p, nil
}
}
return c.recorder.GetPeer(routes[0].Peer)
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()
@@ -441,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
} }
func (c *Client) getRouteManager() (routemanager.Manager, error) { func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.getConnectClient() client := c.connectClient
if client == nil { if client == nil {
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }

View File

@@ -1,19 +1,10 @@
package android package android
import ( import "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
var ( var (
// EnvKeyNBForceRelay Exported for Android java client to force relay connections // EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
) )
// EnvList wraps a Go map for export to Java // EnvList wraps a Go map for export to Java

View File

@@ -2,21 +2,11 @@
package android package android
import "github.com/netbirdio/netbird/client/internal/peer"
// Connection status constants exported via gomobile.
const (
ConnStatusIdle = int(peer.StatusIdle)
ConnStatusConnecting = int(peer.StatusConnecting)
ConnStatusConnected = int(peer.StatusConnected)
)
// PeerInfo describe information about the peers. It designed for the UI usage // PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct { type PeerInfo struct {
IP string IP string
IPv6 string
FQDN string FQDN string
ConnStatus int ConnStatus string // Todo replace to enum
Routes PeerRoutes Routes PeerRoutes
} }

View File

@@ -7,5 +7,4 @@ package android
type PlatformFiles interface { type PlatformFiles interface {
ConfigurationFilePath() string ConfigurationFilePath() string
StateFilePath() string StateFilePath() string
CacheDir() string
} }

View File

@@ -307,24 +307,6 @@ func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block p.configInput.BlockInbound = &block
} }
// GetDisableIPv6 reads disable IPv6 setting from config file
func (p *Preferences) GetDisableIPv6() (bool, error) {
if p.configInput.DisableIPv6 != nil {
return *p.configInput.DisableIPv6, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableIPv6, err
}
// SetDisableIPv6 stores the given value and waits for commit
func (p *Preferences) SetDisableIPv6(disable bool) {
p.configInput.DisableIPv6 = &disable
}
// Commit writes out the changes to the config file // Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := profilemanager.UpdateOrCreateConfig(p.configInput) _, err := profilemanager.UpdateOrCreateConfig(p.configInput)

View File

@@ -18,12 +18,9 @@ func executeRouteToggle(id string, manager routemanager.Manager,
netID := route.NetID(id) netID := route.NetID(id)
routes := []route.NetID{netID} routes := []route.NetID{netID}
routesMap := manager.GetClientRoutesWithNetID() log.Debugf("%s with id: %s", operationName, id)
routes = route.ExpandV6ExitPairs(routes, routesMap)
log.Debugf("%s with ids: %v", operationName, routes) if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
log.Debugf("error when %s: %s", operationName, err) log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err) return fmt.Errorf("error %s: %w", operationName, err)
} }

View File

@@ -9,7 +9,6 @@ import (
"net/url" "net/url"
"regexp" "regexp"
"slices" "slices"
"strconv"
"strings" "strings"
) )
@@ -27,9 +26,8 @@ type Anonymizer struct {
} }
func DefaultAddresses() (netip.Addr, netip.Addr) { func DefaultAddresses() (netip.Addr, netip.Addr) {
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48) // 198.51.100.0, 100::
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android. return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
} }
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
@@ -50,7 +48,7 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
ip.IsLinkLocalUnicast() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() || ip.IsInterfaceLocalMulticast() ||
(ip.Is4() && ip.IsPrivate()) || ip.IsPrivate() ||
ip.IsUnspecified() || ip.IsUnspecified() ||
ip.IsMulticast() || ip.IsMulticast() ||
isWellKnown(ip) || isWellKnown(ip) ||
@@ -98,11 +96,6 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
} }
func (a *Anonymizer) AnonymizeIPString(ip string) string { func (a *Anonymizer) AnonymizeIPString(ip string) string {
// Handle CIDR notation (e.g. "2001:db8::/32")
if prefix, err := netip.ParsePrefix(ip); err == nil {
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
}
addr, err := netip.ParseAddr(ip) addr, err := netip.ParseAddr(ip)
if err != nil { if err != nil {
return ip return ip
@@ -157,7 +150,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
if u.Opaque != "" { if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque) host, port, err := net.SplitHostPort(u.Opaque)
if err == nil { if err == nil {
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port) anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else { } else {
anonymizedHost = a.AnonymizeDomain(u.Opaque) anonymizedHost = a.AnonymizeDomain(u.Opaque)
} }
@@ -165,7 +158,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
} else if u.Host != "" { } else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host) host, port, err := net.SplitHostPort(u.Host)
if err == nil { if err == nil {
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port) anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else { } else {
anonymizedHost = a.AnonymizeDomain(u.Host) anonymizedHost = a.AnonymizeDomain(u.Host)
} }

View File

@@ -13,7 +13,7 @@ import (
func TestAnonymizeIP(t *testing.T) { func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0") startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("2001:db8:ffff::") startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6) anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct { tests := []struct {
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, {"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, {"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"}, {"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, {"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"}, {"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, {"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"}, {"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"}, {"In Range IPv4", "198.51.100.2", "198.51.100.2"},
} }
@@ -274,27 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
{ {
name: "IPv6 Address", name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42", input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 2001:db8:ffff::", expect: "Access attempted from 100::",
}, },
{ {
name: "IPv6 Address with Port", name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080", input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [2001:db8:ffff::]:8080", expect: "Access attempted from [100::]:8080",
}, },
{ {
name: "Both IPv4 and IPv6", name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1", expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
{
name: "STUN URI with IPv6",
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
},
{
name: "HTTPS URI with IPv6",
input: "Visit https://[2001:db8::ff00:42]:443/path",
expect: "Visit https://[2001:db8:ffff::]:443/path",
}, },
} }

View File

@@ -1,196 +0,0 @@
package cmd
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"github.com/hashicorp/go-multierror"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util/capture"
)
var captureCmd = &cobra.Command{
Use: "capture",
Short: "Capture packets on the WireGuard interface",
Long: `Captures decrypted packets flowing through the WireGuard interface.
Default output is human-readable text. Use --pcap or --output for pcap binary.
Requires --enable-capture to be set at service install or reconfigure time.
Examples:
netbird debug capture
netbird debug capture host 100.64.0.1 and port 443
netbird debug capture tcp
netbird debug capture icmp
netbird debug capture src host 10.0.0.1 and dst port 80
netbird debug capture -o capture.pcap
netbird debug capture --pcap | tshark -r -
netbird debug capture --pcap | tcpdump -r - -n`,
Args: cobra.ArbitraryArgs,
RunE: runCapture,
}
func init() {
debugCmd.AddCommand(captureCmd)
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
}
func runCapture(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
cmd.PrintErrf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
req, err := buildCaptureRequest(cmd, args)
if err != nil {
return err
}
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
stream, err := client.StartCapture(ctx, req)
if err != nil {
return handleCaptureError(err)
}
// First Recv is the empty acceptance message from the server. If the
// device is unavailable (kernel WG, not connected, capture disabled),
// the server returns an error instead.
if _, err := stream.Recv(); err != nil {
return handleCaptureError(err)
}
out, cleanup, err := captureOutput(cmd)
if err != nil {
return err
}
if req.TextOutput {
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
} else {
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
}
streamErr := streamCapture(ctx, cmd, stream, out)
cleanupErr := cleanup()
if streamErr != nil {
return streamErr
}
return cleanupErr
}
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
req := &proto.StartCaptureRequest{}
if len(args) > 0 {
expr := strings.Join(args, " ")
if _, err := capture.ParseFilter(expr); err != nil {
return nil, fmt.Errorf("invalid filter: %w", err)
}
req.FilterExpr = expr
}
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
req.SnapLen = snap
}
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
if d < 0 {
return nil, fmt.Errorf("duration must not be negative")
}
req.Duration = durationpb.New(d)
}
req.Verbose, _ = cmd.Flags().GetBool("verbose")
req.Ascii, _ = cmd.Flags().GetBool("ascii")
outPath, _ := cmd.Flags().GetString("output")
forcePcap, _ := cmd.Flags().GetBool("pcap")
req.TextOutput = !forcePcap && outPath == ""
return req, nil
}
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
for {
pkt, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.PrintErrf("\nCapture stopped.\n")
return nil //nolint:nilerr // user interrupted
}
if err == io.EOF {
cmd.PrintErrf("\nCapture finished.\n")
return nil
}
return handleCaptureError(err)
}
if _, err := out.Write(pkt.GetData()); err != nil {
return fmt.Errorf("write output: %w", err)
}
}
}
// captureOutput returns the writer for capture data and a cleanup function
// that finalizes the file. Errors from the cleanup must be propagated.
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
outPath, _ := cmd.Flags().GetString("output")
if outPath == "" {
return os.Stdout, func() error { return nil }, nil
}
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
if err != nil {
return nil, nil, fmt.Errorf("create output file: %w", err)
}
tmpPath := f.Name()
return f, func() error {
var merr *multierror.Error
if err := f.Close(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
}
fi, statErr := os.Stat(tmpPath)
if statErr != nil || fi.Size() == 0 {
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
}
return nberrors.FormatErrorOrNil(merr)
}
if err := os.Rename(tmpPath, outPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
return nberrors.FormatErrorOrNil(merr)
}
cmd.PrintErrf("Wrote %s\n", outPath)
return nberrors.FormatErrorOrNil(merr)
}, nil
}
func handleCaptureError(err error) error {
if s, ok := status.FromError(err); ok {
return fmt.Errorf("%s", s.Message())
}
return err
}

View File

@@ -9,7 +9,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/debug"
@@ -182,11 +181,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown { if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
} }
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
} }
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
@@ -200,13 +198,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.") cmd.Println("Log level set to trace.")
} }
needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} else {
needsRestoreUp = !stateWasDown
cmd.Println("netbird down")
} }
cmd.Println("netbird down")
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@@ -214,15 +209,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{ if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
Enabled: true, Enabled: true,
}); err != nil { }); err != nil {
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message()) return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
} }
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} else {
needsRestoreUp = false
cmd.Println("netbird up")
} }
cmd.Println("netbird up")
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
@@ -240,50 +233,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}() }()
} }
captureStarted := false
if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
captureTimeout := duration + 30*time.Second
const maxBundleCapture = 10 * time.Minute
if captureTimeout > maxBundleCapture {
captureTimeout = maxBundleCapture
}
_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
Timeout: durationpb.New(captureTimeout),
})
if err != nil {
cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
} else {
captureStarted = true
cmd.Println("Packet capture started.")
// Safety: always stop on exit, even if the normal stop below runs too.
defer func() {
if captureStarted {
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
}
}
}()
}
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr return waitErr
} }
cmd.Println("\nDuration completed") cmd.Println("\nDuration completed")
if captureStarted {
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
} else {
captureStarted = false
cmd.Println("Packet capture stopped.")
}
}
if cpuProfilingStarted { if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
@@ -307,28 +261,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
if needsRestoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up (restored)")
}
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
} }
cmd.Println("netbird down")
} }
if !initialLevelTrace { if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil { if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message()) return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
} else {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
} }
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
} }
cmd.Printf("Local file:\n%s\n", resp.GetPath()) cmd.Printf("Local file:\n%s\n", resp.GetPath())
@@ -456,5 +400,4 @@ func init() {
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
} }

View File

@@ -1,287 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
exposeExternalPort uint16
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: ` netbird expose --with-password safe-pass 8080
netbird expose --protocol tcp 5432
netbird expose --protocol tcp --with-external-port 5433 5432
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
}
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
func isClusterProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp", "tls":
return true
default:
return false
}
}
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
// where domain display doesn't apply. TLS uses SNI so it has a domain.
func isPortBasedProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp":
return true
default:
return false
}
}
// extractPort returns the port portion of a URL like "tcp://host:12345", or
// falls back to the given default formatted as a string.
func extractPort(serviceURL string, fallback uint16) string {
u := serviceURL
if idx := strings.Index(u, "://"); idx != -1 {
u = u[idx+3:]
}
if i := strings.LastIndex(u, ":"); i != -1 {
if p := u[i+1:]; p != "" {
return p
}
}
return strconv.FormatUint(uint64(fallback), 10)
}
// resolveExternalPort returns the effective external port, defaulting to the target port.
func resolveExternalPort(targetPort uint64) uint16 {
if exposeExternalPort != 0 {
return exposeExternalPort
}
return uint16(targetPort)
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
}
if isClusterProtocol(exposeProtocol) {
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
}
} else if cmd.Flags().Changed("with-external-port") {
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
switch strings.ToLower(exposeProtocol) {
case "http", "https", "tcp", "udp", "tls":
return true
default:
return false
}
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
req := &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
}
if isClusterProtocol(exposeProtocol) {
req.ListenPort = uint32(resolveExternalPort(port))
}
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
p, err := expose.ParseProtocolType(exposeProtocol)
if err != nil {
return 0, fmt.Errorf("invalid protocol: %w", err)
}
switch p {
case expose.ProtocolHTTP:
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case expose.ProtocolHTTPS:
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
case expose.ProtocolTCP:
return proto.ExposeProtocol_EXPOSE_TCP, nil
case expose.ProtocolUDP:
return proto.ExposeProtocol_EXPOSE_UDP, nil
case expose.ProtocolTLS:
return proto.ExposeProtocol_EXPOSE_TLS, nil
default:
return 0, fmt.Errorf("unhandled protocol type: %d", p)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
if !ok {
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
printExposeReady(cmd, ready.Ready, port)
return nil
}
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", r.ServiceName)
if r.ServiceUrl != "" {
cmd.Printf(" URL: %s\n", r.ServiceUrl)
}
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
cmd.Printf(" Domain: %s\n", r.Domain)
}
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Internal: %d\n", port)
if isClusterProtocol(exposeProtocol) {
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
}
if r.PortAutoAssigned && exposeExternalPort != 0 {
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
}
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@@ -24,7 +23,6 @@ import (
func init() { func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location") loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
} }
@@ -258,7 +256,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
} }
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error { func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil { if err != nil {
@@ -284,9 +282,13 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
} }
defer authClient.Close() defer authClient.Close()
needsLogin, err := authClient.IsLoginRequired(ctx) needsLogin := false
if err != nil {
return fmt.Errorf("check login required: %v", err) err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
} }
jwtToken := "" jwtToken := ""
@@ -326,7 +328,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil { if err != nil {
@@ -336,7 +338,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return &tokenInfo, nil return &tokenInfo, nil
} }
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
var codeMsg string var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -350,12 +352,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
verificationURIComplete + " " + codeMsg) verificationURIComplete + " " + codeMsg)
} }
if showQR {
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
printQRCode(f, verificationURIComplete)
}
}
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {

View File

@@ -1,25 +0,0 @@
package cmd
import (
"io"
"github.com/mdp/qrterminal/v3"
)
// printQRCode prints a QR code for the given URL to the writer.
// Called only when the user explicitly requests QR output via --qr.
func printQRCode(w io.Writer, url string) {
if url == "" {
return
}
qrterminal.GenerateWithConfig(url, qrterminal.Config{
Level: qrterminal.M,
Writer: w,
HalfBlocks: true,
BlackChar: qrterminal.BLACK_BLACK,
WhiteChar: qrterminal.WHITE_WHITE,
BlackWhiteChar: qrterminal.BLACK_WHITE,
WhiteBlackChar: qrterminal.WHITE_BLACK,
QuietZone: qrterminal.QUIET_ZONE,
})
}

View File

@@ -1,26 +0,0 @@
package cmd
import (
"bytes"
"testing"
)
func TestPrintQRCode_EmptyURL(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "")
if buf.Len() != 0 {
t.Error("expected no output for empty URL")
}
}
func TestPrintQRCode_WritesOutput(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "https://example.com/auth")
if buf.Len() == 0 {
t.Error("expected QR code output for non-empty URL")
}
}

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
) )
@@ -75,23 +74,12 @@ var (
mtu uint16 mtu uint16
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
Short: "", Short: "",
Long: "", Long: "",
SilenceUsage: true, SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
} }
) )
@@ -156,7 +144,6 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd) rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -398,6 +385,7 @@ func migrateToNetbird(oldPath, newPath string) bool {
} }
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -410,13 +398,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil return conn, nil
} }
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -41,17 +41,13 @@ func init() {
defaultServiceName = "Netbird" defaultServiceName = "Netbird"
} }
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` + serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` +
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
`Use --service-env "" to clear all saved env vars. ` +
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value` `E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
} }
} }
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled) serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
if err := serverInstance.Start(); err != nil { if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err) log.Fatalf("failed to start daemon: %v", err)
} }
@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
// Common setup for service control commands // Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) { func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
// rootCmd env vars are already applied by PersistentPreRunE. SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd) SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())

View File

@@ -59,14 +59,6 @@ func buildServiceArguments() []string {
args = append(args, "--disable-update-settings") args = append(args, "--disable-update-settings")
} }
if captureEnabled {
args = append(args, "--enable-capture")
}
if networksDisabled {
args = append(args, "--disable-networks")
}
return args return args
} }
@@ -127,10 +119,6 @@ var installCmd = &cobra.Command{
return err return err
} }
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
svcConfig, err := createServiceConfigForInstall() svcConfig, err := createServiceConfigForInstall()
if err != nil { if err != nil {
return err return err
@@ -148,10 +136,6 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err) return fmt.Errorf("install service: %w", err)
} }
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
cmd.Println("NetBird service has been installed") cmd.Println("NetBird service has been installed")
return nil return nil
}, },
@@ -203,10 +187,6 @@ This command will temporarily stop the service, update its configuration, and re
return err return err
} }
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
wasRunning, err := isServiceRunning() wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) { if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err) return fmt.Errorf("check service status: %w", err)
@@ -242,10 +222,6 @@ This command will temporarily stop the service, update its configuration, and re
return fmt.Errorf("install service with new config: %w", err) return fmt.Errorf("install service with new config: %w", err)
} }
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
if wasRunning { if wasRunning {
cmd.Println("Starting NetBird service...") cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil { if err := s.Start(); err != nil {

View File

@@ -1,224 +0,0 @@
//go:build !ios && !android
package cmd
import (
"context"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/util"
)
const serviceParamsFile = "service.json"
// serviceParams holds install-time service parameters that persist across
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
type serviceParams struct {
LogLevel string `json:"log_level"`
DaemonAddr string `json:"daemon_addr"`
ManagementURL string `json:"management_url,omitempty"`
ConfigPath string `json:"config_path,omitempty"`
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
EnableCapture bool `json:"enable_capture,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
// serviceParamsPath returns the path to the service params file.
func serviceParamsPath() string {
return filepath.Join(configs.StateDir, serviceParamsFile)
}
// loadServiceParams reads saved service parameters from disk.
// Returns nil with no error if the file does not exist.
func loadServiceParams() (*serviceParams, error) {
path := serviceParamsPath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil //nolint:nilnil
}
return nil, fmt.Errorf("read service params %s: %w", path, err)
}
var params serviceParams
if err := json.Unmarshal(data, &params); err != nil {
return nil, fmt.Errorf("parse service params %s: %w", path, err)
}
return &params, nil
}
// saveServiceParams writes current service parameters to disk atomically
// with restricted permissions.
func saveServiceParams(params *serviceParams) error {
path := serviceParamsPath()
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
return fmt.Errorf("save service params: %w", err)
}
return nil
}
// currentServiceParams captures the current state of all package-level
// variables into a serviceParams struct.
func currentServiceParams() *serviceParams {
params := &serviceParams{
LogLevel: logLevel,
DaemonAddr: daemonAddr,
ManagementURL: managementURL,
ConfigPath: configPath,
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
EnableCapture: captureEnabled,
DisableNetworks: networksDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil {
params.ServiceEnvVars = parsed
}
}
return params
}
// loadAndApplyServiceParams loads saved params from disk and applies them
// to any flags that were not explicitly set.
func loadAndApplyServiceParams(cmd *cobra.Command) error {
params, err := loadServiceParams()
if err != nil {
return err
}
applyServiceParams(cmd, params)
return nil
}
// applyServiceParams merges saved parameters into package-level variables
// for any flag that was not explicitly set by the user (via CLI or env var).
// Flags that were Changed() are left untouched.
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
if params == nil {
return
}
// For fields with non-empty defaults (log-level, daemon-addr), keep the
// != "" guard so that an older service.json missing the field doesn't
// clobber the default with an empty string.
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
logLevel = params.LogLevel
}
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
daemonAddr = params.DaemonAddr
}
// For optional fields where empty means "use default", always apply so
// that an explicit clear (--management-url "") persists across reinstalls.
if !rootCmd.PersistentFlags().Changed("management-url") {
managementURL = params.ManagementURL
}
if !rootCmd.PersistentFlags().Changed("config") {
configPath = params.ConfigPath
}
if !rootCmd.PersistentFlags().Changed("log-file") {
logFiles = params.LogFiles
}
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
profilesDisabled = params.DisableProfiles
}
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
updateSettingsDisabled = params.DisableUpdateSettings
}
if !serviceCmd.PersistentFlags().Changed("enable-capture") {
captureEnabled = params.EnableCapture
}
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
networksDisabled = params.DisableNetworks
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set with values, explicit values win on key
// conflict but saved keys not in the explicit set are carried over.
// If --service-env was explicitly set to empty, all saved env vars are cleared.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if !cmd.Flags().Changed("service-env") {
if len(params.ServiceEnvVars) > 0 {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
}
return
}
// Flag was explicitly set: parse what the user provided.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
// If the user passed an empty value (e.g. --service-env ""), clear all
// saved env vars rather than merging.
if len(explicit) == 0 {
serviceEnvVars = nil
return
}
if len(params.ServiceEnvVars) == 0 {
return
}
// Merge saved values underneath explicit ones.
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict
serviceEnvVars = envMapToSlice(merged)
}
var resetParamsCmd = &cobra.Command{
Use: "reset-params",
Short: "Remove saved service install parameters",
Long: "Removes the saved service.json file so the next install uses default parameters.",
RunE: func(cmd *cobra.Command, args []string) error {
path := serviceParamsPath()
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
cmd.Println("No saved service parameters found")
return nil
}
return fmt.Errorf("remove service params: %w", err)
}
cmd.Printf("Removed saved service parameters (%s)\n", path)
return nil
},
}
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
func envMapToSlice(m map[string]string) []string {
s := make([]string, 0, len(m))
for k, v := range m {
s = append(s, k+"="+v)
}
return s
}

View File

@@ -1,560 +0,0 @@
//go:build !ios && !android
package cmd
import (
"encoding/json"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/configs"
)
func TestServiceParamsPath(t *testing.T) {
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params := &serviceParams{
LogLevel: "debug",
DaemonAddr: "unix:///var/run/netbird.sock",
ManagementURL: "https://my.server.com",
ConfigPath: "/etc/netbird/config.json",
LogFiles: []string{"/var/log/netbird/client.log", "console"},
DisableProfiles: true,
DisableUpdateSettings: false,
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
}
err := saveServiceParams(params)
require.NoError(t, err)
// Verify the file exists and is valid JSON.
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
require.NoError(t, err)
assert.True(t, json.Valid(data))
loaded, err := loadServiceParams()
require.NoError(t, err)
require.NotNil(t, loaded)
assert.Equal(t, params.LogLevel, loaded.LogLevel)
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
assert.Equal(t, params.LogFiles, loaded.LogFiles)
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
}
func TestLoadServiceParams_FileNotExists(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params, err := loadServiceParams()
assert.NoError(t, err)
assert.Nil(t, params)
}
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
require.NoError(t, err)
params, err := loadServiceParams()
assert.Error(t, err)
assert.Nil(t, params)
}
func TestCurrentServiceParams(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
logLevel = "trace"
daemonAddr = "tcp://127.0.0.1:9999"
managementURL = "https://mgmt.example.com"
configPath = "/tmp/test-config.json"
logFiles = []string{"/tmp/test.log"}
profilesDisabled = true
updateSettingsDisabled = true
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
params := currentServiceParams()
assert.Equal(t, "trace", params.LogLevel)
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
assert.True(t, params.DisableProfiles)
assert.True(t, params.DisableUpdateSettings)
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
}
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
// Reset all flags to defaults.
logLevel = "info"
daemonAddr = "unix:///var/run/netbird.sock"
managementURL = ""
configPath = "/etc/netbird/config.json"
logFiles = []string{"/var/log/netbird/client.log"}
profilesDisabled = false
updateSettingsDisabled = false
serviceEnvVars = nil
// Reset Changed state on all relevant flags.
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Simulate user explicitly setting --log-level via CLI.
logLevel = "warn"
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
saved := &serviceParams{
LogLevel: "debug",
DaemonAddr: "tcp://127.0.0.1:5555",
ManagementURL: "https://saved.example.com",
ConfigPath: "/saved/config.json",
LogFiles: []string{"/saved/client.log"},
DisableProfiles: true,
DisableUpdateSettings: true,
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
// log-level was Changed, so it should keep "warn", not use saved "debug".
assert.Equal(t, "warn", logLevel)
// All other fields were not Changed, so they should use saved values.
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
assert.Equal(t, "https://saved.example.com", managementURL)
assert.Equal(t, "/saved/config.json", configPath)
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
assert.True(t, profilesDisabled)
assert.True(t, updateSettingsDisabled)
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
}
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
t.Cleanup(func() {
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
})
// Simulate current state where booleans are true (e.g. set by previous install).
profilesDisabled = true
updateSettingsDisabled = true
// Reset Changed state so flags appear unset.
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Saved params have both as false.
saved := &serviceParams{
DisableProfiles: false,
DisableUpdateSettings: false,
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.False(t, profilesDisabled, "saved false should override current true")
assert.False(t, updateSettingsDisabled, "saved false should override current true")
}
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
origManagementURL := managementURL
t.Cleanup(func() { managementURL = origManagementURL })
managementURL = "https://leftover.example.com"
// Simulate saved params where management URL was explicitly cleared.
saved := &serviceParams{
LogLevel: "info",
DaemonAddr: "unix:///var/run/netbird.sock",
// ManagementURL intentionally empty: was cleared with --management-url "".
}
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
}
func TestApplyServiceParams_NilParams(t *testing.T) {
origLogLevel := logLevel
t.Cleanup(func() { logLevel = origLogLevel })
logLevel = "info"
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
// Should be a no-op.
applyServiceParams(cmd, nil)
assert.Equal(t, "info", logLevel)
}
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Set up a command with --service-env marked as Changed.
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
saved := &serviceParams{
ServiceEnvVars: map[string]string{
"SAVED": "val",
"OVERLAP": "saved",
},
}
applyServiceEnvParams(cmd, saved)
// Parse result for easier assertion.
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, "yes", result["EXPLICIT"])
assert.Equal(t, "val", result["SAVED"])
// Explicit wins on conflict.
assert.Equal(t, "explicit", result["OVERLAP"])
}
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
serviceEnvVars = nil
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
saved := &serviceParams{
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
}
applyServiceEnvParams(cmd, saved)
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", ""))
saved := &serviceParams{
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
}
applyServiceEnvParams(cmd, saved)
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
}
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
params := currentServiceParams()
// After parsing, the empty string is skipped, resulting in an empty map.
// The map should still be set (not nil) so it overwrites saved values.
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
// Collect all JSON field names from the serviceParams struct.
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
// Collect field names referenced in currentServiceParams and applyServiceParams.
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
// applyServiceEnvParams handles ServiceEnvVars indirectly.
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
for k, v := range applyEnvFields {
applyFields[k] = v
}
for _, field := range structFields {
assert.Contains(t, currentFields, field,
"serviceParams field %q is not captured in currentServiceParams()", field)
assert.Contains(t, applyFields, field,
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
}
}
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
// it flows through newSVCConfig() EnvVars, not CLI args.
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields)
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
require.NoError(t, err)
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
fieldsNotInArgs := map[string]bool{
"ServiceEnvVars": true,
}
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
// Forward: every struct field must appear in buildServiceArguments.
for _, field := range structFields {
if fieldsNotInArgs[field] {
continue
}
globalVar := fieldToGlobalVar(field)
assert.Contains(t, buildFields, globalVar,
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
}
// Reverse: every service-related global used in buildServiceArguments must
// have a corresponding serviceParams field. This catches a developer adding
// a new flag to buildServiceArguments without adding it to the struct.
globalToField := make(map[string]string, len(structFields))
for _, field := range structFields {
globalToField[fieldToGlobalVar(field)] = field
}
// Identifiers in buildServiceArguments that are not service params
// (builtins, boilerplate, loop variables).
nonParamGlobals := map[string]bool{
"args": true, "append": true, "string": true, "_": true,
"logFile": true, // range variable over logFiles
}
for ref := range buildFields {
if nonParamGlobals[ref] {
continue
}
_, inStruct := globalToField[ref]
assert.True(t, inStruct,
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
}
}
// extractStructJSONFields returns field names from a named struct type.
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
t.Helper()
var fields []string
ast.Inspect(file, func(n ast.Node) bool {
ts, ok := n.(*ast.TypeSpec)
if !ok || ts.Name.Name != structName {
return true
}
st, ok := ts.Type.(*ast.StructType)
if !ok {
return false
}
for _, f := range st.Fields.List {
if len(f.Names) > 0 {
fields = append(fields, f.Names[0].Name)
}
}
return false
})
return fields
}
// extractFuncFieldRefs returns which of the given field names appear inside the
// named function, either as selector expressions (params.FieldName) or as
// composite literal keys (&serviceParams{FieldName: ...}).
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
t.Helper()
fieldSet := make(map[string]bool, len(fields))
for _, f := range fields {
fieldSet[f] = true
}
found := make(map[string]bool)
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
ast.Inspect(fn.Body, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.SelectorExpr:
if fieldSet[v.Sel.Name] {
found[v.Sel.Name] = true
}
case *ast.KeyValueExpr:
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
found[ident.Name] = true
}
}
return true
})
return found
}
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
t.Helper()
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
refs := make(map[string]bool)
ast.Inspect(fn.Body, func(n ast.Node) bool {
if ident, ok := n.(*ast.Ident); ok {
refs[ident.Name] = true
}
return true
})
return refs
}
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
for _, decl := range file.Decls {
fn, ok := decl.(*ast.FuncDecl)
if ok && fn.Name.Name == name {
return fn
}
}
return nil
}
// fieldToGlobalVar maps serviceParams field names to the package-level variable
// names used in buildServiceArguments and applyServiceParams.
func fieldToGlobalVar(field string) string {
m := map[string]string{
"LogLevel": "logLevel",
"DaemonAddr": "daemonAddr",
"ManagementURL": "managementURL",
"ConfigPath": "configPath",
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"EnableCapture": "captureEnabled",
"DisableNetworks": "networksDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {
return v
}
// Default: lowercase first letter.
return strings.ToLower(field[:1]) + field[1:]
}
func TestEnvMapToSlice(t *testing.T) {
m := map[string]string{"A": "1", "B": "2"}
s := envMapToSlice(m)
assert.Len(t, s, 2)
assert.Contains(t, s, "A=1")
assert.Contains(t, s, "B=2")
}
func TestEnvMapToSlice_Empty(t *testing.T) {
s := envMapToSlice(map[string]string{})
assert.Empty(t, s)
}

View File

@@ -4,9 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/signal"
"runtime" "runtime"
"syscall"
"testing" "testing"
"time" "time"
@@ -15,22 +13,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const ( const (
serviceStartTimeout = 10 * time.Second serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second serviceStopTimeout = 5 * time.Second
@@ -97,34 +79,6 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info" logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir) daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background() ctx := context.Background()
t.Run("Install", func(t *testing.T) { t.Run("Install", func(t *testing.T) {

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign" "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
) )
var ( var (

View File

@@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign" "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
) )
const ( const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign" "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
) )
const ( const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign" "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
) )
var ( var (

View File

@@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
} }
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port)) target := fmt.Sprintf("%s:%d", addr, port)
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{ c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile, KnownHostsFile: knownHostsFile,
IdentityFile: identityFile, IdentityFile: identityFile,
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
} }
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack). // normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
func normalizeLocalHost(host string) string { func normalizeLocalHost(host string) string {
if host == "*" { if host == "*" {
return "" return "0.0.0.0"
} }
return host return host
} }

View File

@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
{ {
name: "wildcard bind all interfaces", name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80", spec: "*:8080:localhost:80",
expectedLocal: ":8080", expectedLocal: "0.0.0.0:8080",
expectedRemote: "localhost:80", expectedRemote: "localhost:80",
expectError: false, expectError: false,
description: "Wildcard * should bind to all interfaces (dual-stack)", description: "Wildcard * should bind to all interfaces (0.0.0.0)",
}, },
{ {
name: "wildcard for port only", name: "wildcard for port only",

View File

@@ -20,7 +20,6 @@ import (
var ( var (
detailFlag bool detailFlag bool
ipv4Flag bool ipv4Flag bool
ipv6Flag bool
jsonFlag bool jsonFlag bool
yamlFlag bool yamlFlag bool
ipsFilter []string ipsFilter []string
@@ -29,7 +28,6 @@ var (
ipsFilterMap map[string]struct{} ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string connectionTypeFilter string
checkFlag string
) )
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
@@ -46,13 +44,11 @@ func init() {
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -60,10 +56,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
if checkFlag != "" {
return runHealthCheck(cmd)
}
err := parseFilters() err := parseFilters()
if err != nil { if err != nil {
return err return err
@@ -76,17 +68,15 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, true, false) resp, err := getStatus(ctx, false)
if err != nil { if err != nil {
return err return err
} }
status := resp.GetStatus() status := resp.GetStatus()
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) || if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) status == string(internal.StatusSessionExpired) {
if needsAuth && !jsonFlag && !yamlFlag {
cmd.Printf("Daemon status: %s\n\n"+ cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+ " netbird up \n\n"+
@@ -103,31 +93,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
if ipv6Flag {
ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6()
if ipv6 != "" {
cmd.Print(parseInterfaceIP(ipv6))
}
return nil
}
pm := profilemanager.NewProfileManager() pm := profilemanager.NewProfileManager()
var profName string var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil { if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name profName = activeProf.Name
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{ var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
Anonymize: anonymizeFlag,
DaemonVersion: resp.GetDaemonVersion(),
DaemonStatus: nbstatus.ParseDaemonStatus(status),
StatusFilter: statusFilter,
PrefixNamesFilter: prefixNamesFilter,
PrefixNamesFilterMap: prefixNamesFilterMap,
IPsFilter: ipsFilterMap,
ConnectionTypeFilter: connectionTypeFilter,
ProfileName: profName,
})
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
@@ -149,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) { func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
//nolint //nolint
@@ -159,7 +131,7 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes}) resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }
@@ -213,83 +185,6 @@ func enableDetailFlagWhenFilterFlag() {
} }
} }
func runHealthCheck(cmd *cobra.Command) error {
check := strings.ToLower(checkFlag)
switch check {
case "live", "ready", "startup":
default:
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
}
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
isStartup := check == "startup"
resp, err := getStatus(ctx, isStartup, false)
if err != nil {
return err
}
switch check {
case "live":
return nil
case "ready":
return checkReadiness(resp)
case "startup":
return checkStartup(resp)
default:
return nil
}
}
func checkReadiness(resp *proto.StatusResponse) error {
daemonStatus := internal.StatusType(resp.GetStatus())
switch daemonStatus {
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
return nil
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
default:
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
}
}
func checkStartup(resp *proto.StatusResponse) error {
fullStatus := resp.GetFullStatus()
if fullStatus == nil {
return fmt.Errorf("startup check: no full status available")
}
if !fullStatus.GetManagementState().GetConnected() {
return fmt.Errorf("startup check: management not connected")
}
if !fullStatus.GetSignalState().GetConnected() {
return fmt.Errorf("startup check: signal not connected")
}
var relayCount, relaysConnected int
for _, r := range fullStatus.GetRelays() {
uri := r.GetURI()
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
continue
}
relayCount++
if r.GetAvailable() {
relaysConnected++
}
}
if relayCount > 0 && relaysConnected == 0 {
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
}
return nil
}
func parseInterfaceIP(interfaceIP string) string { func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP) ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil { if err != nil {

View File

@@ -8,7 +8,6 @@ const (
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access" blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound" blockInboundFlag = "block-inbound"
disableIPv6Flag = "disable-ipv6"
) )
var ( var (
@@ -18,7 +17,6 @@ var (
disableFirewall bool disableFirewall bool
blockLANAccess bool blockLANAccess bool
blockInbound bool blockInbound bool
disableIPv6 bool
) )
func init() { func init() {
@@ -41,7 +39,4 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.") "This overrides any policies received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false,
"Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.")
} }

View File

@@ -13,8 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -102,16 +100,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
jobManager := job.NewJobManager(nil, store, peersmanager) jobManager := job.NewJobManager(nil, store, peersmanager)
ctx := context.Background() iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatal(err)
}
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err) require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
@@ -122,11 +113,12 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics) updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -135,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -160,7 +152,7 @@ func startClientDaemon(
s := grpc.NewServer() s := grpc.NewServer()
server := client.New(ctx, server := client.New(ctx,
"", "", false, false, false, false) "", "", false, false)
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -39,9 +39,6 @@ const (
noBrowserFlag = "no-browser" noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login" noBrowserDesc = "do not open the browser for SSO login"
showQRFlag = "qr"
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
profileNameFlag = "profile" profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used." profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
) )
@@ -51,7 +48,6 @@ var (
dnsLabels []string dnsLabels []string
dnsLabelsValidated domain.List dnsLabelsValidated domain.List
noBrowser bool noBrowser bool
showQR bool
profileName string profileName string
configPath string configPath string
@@ -84,7 +80,6 @@ func init() {
) )
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ") upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
@@ -202,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String()) r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus() r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r) connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "") SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles)) return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
@@ -435,10 +430,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.BlockInbound = &blockInbound req.BlockInbound = &blockInbound
} }
if cmd.Flag(disableIPv6Flag).Changed {
req.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled req.LazyConnectionEnabled = &lazyConnEnabled
} }
@@ -556,10 +547,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.BlockInbound = &blockInbound ic.BlockInbound = &blockInbound
} }
if cmd.Flag(disableIPv6Flag).Changed {
ic.DisableIPv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled ic.LazyConnectionEnabled = &lazyConnEnabled
} }
@@ -674,10 +661,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockInbound = &blockInbound loginRequest.BlockInbound = &blockInbound
} }
if cmd.Flag(disableIPv6Flag).Changed {
loginRequest.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled loginRequest.LazyConnectionEnabled = &lazyConnEnabled
} }

View File

@@ -11,7 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )

View File

@@ -1,65 +0,0 @@
package embed
import (
"io"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/capture"
)
// CaptureOptions configures a packet capture session.
type CaptureOptions struct {
// Output receives pcap-formatted data. Nil disables pcap output.
Output io.Writer
// TextOutput receives human-readable packet summaries. Nil disables text output.
TextOutput io.Writer
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
// Empty captures all packets.
Filter string
// Verbose adds seq/ack, TTL, window, and total length to text output.
Verbose bool
// ASCII dumps transport payload as printable ASCII after each packet line.
ASCII bool
}
// CaptureStats reports capture session counters.
type CaptureStats struct {
Packets int64
Bytes int64
Dropped int64
}
// CaptureSession represents an active packet capture. Call Stop to end the
// capture and flush buffered packets.
type CaptureSession struct {
sess *capture.Session
engine *internal.Engine
}
// Stop ends the capture, flushes remaining packets, and detaches from the device.
// Safe to call multiple times.
func (cs *CaptureSession) Stop() {
if cs.engine != nil {
_ = cs.engine.SetCapture(nil)
cs.engine = nil
}
if cs.sess != nil {
cs.sess.Stop()
}
}
// Stats returns current capture counters.
func (cs *CaptureSession) Stats() CaptureStats {
s := cs.sess.Stats()
return CaptureStats{
Packets: s.Packets,
Bytes: s.Bytes,
Dropped: s.Dropped,
}
}
// Done returns a channel that is closed when the capture's writer goroutine
// has fully exited and all buffered packets have been flushed.
func (cs *CaptureSession) Done() <-chan struct{} {
return cs.sess.Done()
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
@@ -22,9 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh" sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/capture"
) )
var ( var (
@@ -34,14 +31,6 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized") ErrConfigNotInitialized = errors.New("config not initialized")
) )
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
// Client manages a netbird embedded client instance. // Client manages a netbird embedded client instance.
type Client struct { type Client struct {
deviceName string deviceName string
@@ -66,7 +55,7 @@ type Options struct {
PrivateKey string PrivateKey string
// ManagementURL overrides the default management server URL // ManagementURL overrides the default management server URL
ManagementURL string ManagementURL string
// PreSharedKey is the pre-shared key for the tunnel interface // PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil) // LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer LogOutput io.Writer
@@ -80,20 +69,8 @@ type Options struct {
StatePath string StatePath string
// DisableClientRoutes disables the client routes // DisableClientRoutes disables the client routes
DisableClientRoutes bool DisableClientRoutes bool
// DisableIPv6 disables IPv6 overlay addressing
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers // BlockInbound blocks all inbound connections from peers
BlockInbound bool BlockInbound bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
} }
// validateCredentials checks that exactly one credential type is provided // validateCredentials checks that exactly one credential type is provided
@@ -125,12 +102,6 @@ func New(opts Options) (*Client, error) {
return nil, err return nil, err
} }
if opts.MTU != nil {
if err := iface.ValidateMTU(*opts.MTU); err != nil {
return nil, fmt.Errorf("invalid MTU: %w", err)
}
}
if opts.LogOutput != nil { if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput) logrus.SetOutput(opts.LogOutput)
} }
@@ -159,25 +130,16 @@ func New(opts Options) (*Client, error) {
} }
} }
var err error
var parsedLabels domain.List
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
return nil, fmt.Errorf("invalid dns labels: %w", err)
}
t := true t := true
var config *profilemanager.Config var config *profilemanager.Config
var err error
input := profilemanager.ConfigInput{ input := profilemanager.ConfigInput{
ConfigPath: opts.ConfigPath, ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL, ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey, PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t, DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes, DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound, BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
} }
if opts.ConfigPath != "" { if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input) config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -197,7 +159,6 @@ func New(opts Options) (*Client, error) {
setupKey: opts.SetupKey, setupKey: opts.SetupKey,
jwtToken: opts.JWTToken, jwtToken: opts.JWTToken,
config: config, config: config,
recorder: peer.NewRecorder(config.ManagementURL.String()),
}, nil }, nil
} }
@@ -219,7 +180,6 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config) authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil { if err != nil {
return fmt.Errorf("create auth client: %w", err) return fmt.Errorf("create auth client: %w", err)
@@ -229,7 +189,10 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err) return fmt.Errorf("login: %w", err)
} }
client := internal.NewConnectClient(ctx, c.config, c.recorder)
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true) client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
@@ -379,38 +342,17 @@ func (c *Client) NewHTTPClient() *http.Client {
} }
} }
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
// It returns an ExposeSession. Call Wait on the session to keep it alive.
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
mgr := engine.GetExposeManager()
if mgr == nil {
return nil, fmt.Errorf("expose manager not available")
}
resp, err := mgr.Expose(ctx, req)
if err != nil {
return nil, fmt.Errorf("expose: %w", err)
}
return &ExposeSession{
Domain: resp.Domain,
ServiceName: resp.ServiceName,
ServiceURL: resp.ServiceURL,
mgr: mgr,
}, nil
}
// Status returns the current status of the client. // Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) { func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock() c.mu.Lock()
recorder := c.recorder
connect := c.connect connect := c.connect
c.mu.Unlock() c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil { if connect != nil {
engine := connect.Engine() engine := connect.Engine()
if engine != nil { if engine != nil {
@@ -418,7 +360,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
} }
} }
return c.recorder.GetFullStatus(), nil return recorder.GetFullStatus(), nil
} }
// GetLatestSyncResponse returns the latest sync response from the management server. // GetLatestSyncResponse returns the latest sync response from the management server.
@@ -473,52 +415,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress) return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
} }
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
var matcher capture.Matcher
if opts.Filter != "" {
m, err := capture.ParseFilter(opts.Filter)
if err != nil {
return nil, fmt.Errorf("parse filter: %w", err)
}
matcher = m
}
sess, err := capture.NewSession(capture.Options{
Output: opts.Output,
TextOutput: opts.TextOutput,
Matcher: matcher,
Verbose: opts.Verbose,
ASCII: opts.ASCII,
})
if err != nil {
return nil, fmt.Errorf("create capture session: %w", err)
}
if err := engine.SetCapture(sess); err != nil {
sess.Stop()
return nil, fmt.Errorf("set capture: %w", err)
}
return &CaptureSession{sess: sess, engine: engine}, nil
}
// StopCapture stops the active capture session if one is running.
func (c *Client) StopCapture() error {
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetCapture(nil)
}
// getEngine safely retrieves the engine from the client with proper locking. // getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started. // Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available. // Returns ErrEngineNotStarted if the engine is not available.

View File

@@ -1,45 +0,0 @@
package embed
import (
"context"
"errors"
"github.com/netbirdio/netbird/client/internal/expose"
)
const (
// ExposeProtocolHTTP exposes the service as HTTP.
ExposeProtocolHTTP = expose.ProtocolHTTP
// ExposeProtocolHTTPS exposes the service as HTTPS.
ExposeProtocolHTTPS = expose.ProtocolHTTPS
// ExposeProtocolTCP exposes the service as TCP.
ExposeProtocolTCP = expose.ProtocolTCP
// ExposeProtocolUDP exposes the service as UDP.
ExposeProtocolUDP = expose.ProtocolUDP
// ExposeProtocolTLS exposes the service as TLS.
ExposeProtocolTLS = expose.ProtocolTLS
)
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
type ExposeRequest = expose.Request
// ExposeProtocolType represents the protocol used for exposing a service.
type ExposeProtocolType = expose.ProtocolType
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
type ExposeSession struct {
Domain string
ServiceName string
ServiceURL string
mgr *expose.Manager
}
// Wait blocks while keeping the expose session alive.
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
func (s *ExposeSession) Wait(ctx context.Context) error {
if s == nil || s.mgr == nil {
return errors.New("expose session is not initialized")
}
return s.mgr.KeepAlive(ctx, s.Domain)
}

View File

@@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"strconv"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
@@ -36,34 +35,20 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall. // on the linux system we try to user nftables or iptables
if iface.IsUserspaceBind() && forceUserspaceFirewall() { // in any case, because we need to allow netbird interface traffic
log.Info("forcing userspace firewall") // so we use AllowNetbird traffic from these firewall managers
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu) // for the userspace packet filtering firewall
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return fm, err return fm, err
} }
// Fall back to the userspace packet filter if native is unavailable
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
} }
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
@@ -175,17 +160,3 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter") _, err := client.ListChains("filter")
return err == nil return err == nil
} }
func forceUserspaceFirewall() bool {
val := os.Getenv(EnvForceUserspaceFirewall)
if val == "" {
return false
}
force, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
return false
}
return force
}

View File

@@ -1,11 +0,0 @@
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
// versions, which returns EPERM to any other process that tries to insert
// rules into them. The workaround mirrors what Tailscale does: let firewalld
// itself add the accept rules to its own chains by trusting the interface.
package firewalld
// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"

View File

@@ -1,260 +0,0 @@
//go:build linux
package firewalld
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
)
const (
dbusDest = "org.fedoraproject.FirewallD1"
dbusPath = "/org/fedoraproject/FirewallD1"
dbusRootIface = "org.fedoraproject.FirewallD1"
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
errZoneAlreadySet = "ZONE_ALREADY_SET"
errAlreadyEnabled = "ALREADY_ENABLED"
errUnknownIface = "UNKNOWN_INTERFACE"
errNotEnabled = "NOT_ENABLED"
// callTimeout bounds each individual DBus or firewall-cmd invocation.
// A fresh context is created for each call so a slow DBus probe can't
// exhaust the deadline before the firewall-cmd fallback gets to run.
callTimeout = 3 * time.Second
)
var (
errDBusUnavailable = errors.New("firewalld dbus unavailable")
// trustLogOnce ensures the "added to trusted zone" message is logged at
// Info level only for the first successful add per process; repeat adds
// from other init paths are quieter.
trustLogOnce sync.Once
parentCtxMu sync.RWMutex
parentCtx context.Context = context.Background()
)
// SetParentContext installs a parent context whose cancellation aborts any
// in-flight TrustInterface call. It does not affect UntrustInterface, which
// always uses a fresh Background-rooted timeout so cleanup can still run
// during engine shutdown when the engine context is already cancelled.
func SetParentContext(ctx context.Context) {
parentCtxMu.Lock()
parentCtx = ctx
parentCtxMu.Unlock()
}
func getParentContext() context.Context {
parentCtxMu.RLock()
defer parentCtxMu.RUnlock()
return parentCtx
}
// TrustInterface places iface into firewalld's trusted zone if firewalld is
// running. It is idempotent and best-effort: errors are returned so callers
// can log, but a non-running firewalld is not an error. Only the first
// successful call per process logs at Info. Respects the parent context set
// via SetParentContext so startup-time cancellation unblocks it.
func TrustInterface(iface string) error {
parent := getParentContext()
if !isRunning(parent) {
return nil
}
if err := addTrusted(parent, iface); err != nil {
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
}
trustLogOnce.Do(func() {
log.Infof("added %s to firewalld trusted zone", iface)
})
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
return nil
}
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
// during shutdown after the engine context has been cancelled.
func UntrustInterface(iface string) error {
if !isRunning(context.Background()) {
return nil
}
if err := removeTrusted(context.Background(), iface); err != nil {
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
}
return nil
}
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, callTimeout)
}
func isRunning(parent context.Context) bool {
ctx, cancel := newCallContext(parent)
ok, err := isRunningDBus(ctx)
cancel()
if err == nil {
return ok
}
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
ctx, cancel = newCallContext(parent)
defer cancel()
return isRunningCLI(ctx)
}
return false
}
func addTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := addDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return addCLI(ctx, iface)
}
func removeTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := removeDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return removeCLI(ctx, iface)
}
func isRunningDBus(ctx context.Context) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
var zone string
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
}
return true, nil
}
func isRunningCLI(ctx context.Context) bool {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return false
}
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
}
func addDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errAlreadyEnabled) {
return nil
}
if dbusErrContains(call.Err, errZoneAlreadySet) {
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
if move.Err != nil {
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
}
return nil
}
return fmt.Errorf("firewalld addInterface: %w", call.Err)
}
func removeDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
return nil
}
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
}
func addCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
// --change-interface (no --permanent) binds the interface for the
// current runtime only; we do not want membership to persist across
// reboots because netbird re-asserts it on every startup.
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
).CombinedOutput()
if err != nil {
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
func removeCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
).CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
return nil
}
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
}
return nil
}
func dbusErrContains(err error, code string) bool {
if err == nil {
return false
}
var de dbus.Error
if errors.As(err, &de) {
for _, b := range de.Body {
if s, ok := b.(string); ok && strings.Contains(s, code) {
return true
}
}
}
return strings.Contains(err.Error(), code)
}

View File

@@ -1,49 +0,0 @@
//go:build linux
package firewalld
import (
"errors"
"testing"
"github.com/godbus/dbus/v5"
)
func TestDBusErrContains(t *testing.T) {
tests := []struct {
name string
err error
code string
want bool
}{
{"nil error", nil, errZoneAlreadySet, false},
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
{
"dbus.Error body match",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
errZoneAlreadySet,
true,
},
{
"dbus.Error body miss",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
errAlreadyEnabled,
false,
},
{
"dbus.Error non-string body falls back to Error()",
dbus.Error{Name: "x", Body: []any{123}},
"x",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := dbusErrContains(tc.err, tc.code)
if got != tc.want {
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
}
})
}
}

View File

@@ -1,25 +0,0 @@
//go:build !linux
package firewalld
import "context"
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
// intentionally empty: firewalld is a Linux-only daemon
}
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func TrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func UntrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}

View File

@@ -7,12 +7,6 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
// native iptables/nftables is available. This only applies when the WireGuard interface
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
// kernel netfilter rules.
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string

View File

@@ -21,10 +21,6 @@ const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
) )
type aclEntries map[string][][]string type aclEntries map[string][][]string
@@ -40,7 +36,6 @@ type aclManager struct {
entries aclEntries entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
} }
@@ -52,7 +47,6 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil }, nil
} }
@@ -87,11 +81,7 @@ func (m *aclManager) AddPeerFiltering(
chain := chainNameInputRules chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
if m.v6 && ipsetName != "" { specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs) mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs, mangleSpecs = append(mangleSpecs,
@@ -115,7 +105,6 @@ func (m *aclManager) AddPeerFiltering(
ip: ip.String(), ip: ip.String(),
chain: chain, chain: chain,
specs: specs, specs: specs,
v6: m.v6,
}}, nil }}, nil
} }
@@ -168,7 +157,6 @@ func (m *aclManager) AddPeerFiltering(
ipsetName: ipsetName, ipsetName: ipsetName,
ip: ip.String(), ip: ip.String(),
chain: chain, chain: chain,
v6: m.v6,
} }
m.updateState() m.updateState()
@@ -286,12 +274,6 @@ func (m *aclManager) cleanChains() error {
} }
} }
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil { if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) { if errors.Is(err, ipset.ErrSetNotExist) {
@@ -321,10 +303,6 @@ func (m *aclManager) createDefaultChains() error {
} }
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err) log.Debugf("failed to create input chain jump rule: %s", err)
@@ -344,13 +322,6 @@ func (m *aclManager) createDefaultChains() error {
} }
clear(m.optionalEntries) clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil return nil
} }
@@ -372,22 +343,6 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
} }
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {
@@ -421,13 +376,8 @@ func (m *aclManager) updateState() {
currentState.Lock() currentState.Lock()
defer currentState.Unlock() defer currentState.Unlock()
if m.v6 { currentState.ACLEntries = m.entries
currentState.ACLEntries6 = m.entries currentState.ACLIPsetStore = m.ipsetStore
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil { if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
@@ -435,22 +385,13 @@ func (m *aclManager) updateState() {
} }
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0 // don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified() matchByIP := !ip.IsUnspecified()
if matchByIP { if matchByIP {
if ipsetName != "" { if ipsetName != "" {
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src") specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else { } else {
specs = append(specs, "-s", ip.String()) specs = append(specs, "-s", ip.String())
} }
@@ -496,9 +437,6 @@ func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{ opts := ipset.CreateOptions{
Replace: true, Replace: true,
} }
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err) return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -12,37 +12,27 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
type resetter interface {
Reset() error
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
wgIface iFaceMapper wgIface iFaceMapper
ipv4Client *iptables.IPTables ipv4Client *iptables.IPTables
aclMgr *aclManager aclMgr *aclManager
router *router router *router
rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
aclMgr6 *aclManager
router6 *router
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address
IsUserspaceBind() bool
} }
// Create iptables firewall manager // Create iptables firewall manager
@@ -67,49 +57,16 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
return m, nil return m, nil
} }
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
m.router6, err = newRouter(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
func (m *Manager) hasIPv6() bool {
return m.ipv6Client != nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error { func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{ state := &ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(), WGAddress: m.wgIface.Address(),
MTU: m.router.mtu, UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
}, },
} }
stateManager.RegisterState(state) stateManager.RegisterState(state)
@@ -117,18 +74,17 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
if err := m.initChains(stateManager); err != nil { if err := m.router.init(stateManager); err != nil {
return err return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
} }
if err := m.initNoTrackChain(); err != nil { if err := m.initNoTrackChain(); err != nil {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err) return fmt.Errorf("init notrack chain: %w", err)
}
// Trust after all fatal init steps so a later failure doesn't leave the
// interface in firewalld's trusted zone without a corresponding Close.
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
} }
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
@@ -141,41 +97,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil return nil
} }
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
init func(*statemanager.Manager) error
mgr resetter
}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
if m.hasIPv6() {
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
}
var initialized []initStep
for _, s := range steps {
if err := s.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
return fmt.Errorf("%s init: %w", s.name, err)
}
initialized = append(initialized, s)
}
return nil
}
// AddPeerFiltering adds a rule to the firewall // AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
@@ -191,13 +112,7 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if ip.To4() != nil { return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@@ -211,48 +126,25 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
if !m.hasIPv6() { return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error { func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
}
return m.aclMgr.DeletePeerRule(rule) return m.aclMgr.DeletePeerRule(rule)
} }
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule) return m.router.DeleteRouteRule(rule)
} }
@@ -268,65 +160,18 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { return m.router.AddNatRule(pair)
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables since resolved IPs can be
// either v4 or v6. This covers both DomainSet (modern) and the legacy
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { return m.router.RemoveNatRule(pair)
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { return firewall.SetLegacyManagement(m.router, isLegacy)
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
} }
// Reset firewall to the default state // Reset firewall to the default state
@@ -340,15 +185,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
} }
if m.hasIPv6() {
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
}
}
if err := m.aclMgr.Reset(); err != nil { if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
@@ -356,12 +192,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
} }
// Appending to merr intentionally blocks DeleteState below so ShutdownState
// stays persisted and the crash-recovery path retries firewalld cleanup.
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
// attempt to delete state only if all other operations succeeded // attempt to delete state only if all other operations succeeded
if merr == nil { if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
@@ -372,25 +202,25 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic. // AllowNetbird allows netbird interface traffic
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
var merr *multierror.Error if !m.wgIface.IsUserspaceBind() {
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { return nil
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
}
if m.hasIPv6() {
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
}
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { _, err := m.AddPeerFiltering(
log.Warnf("failed to trust interface in firewalld: %v", err) nil,
net.IP{0, 0, 0, 0},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionAccept,
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
} }
return nil
return nberrors.FormatErrorOrNil(merr)
} }
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
@@ -420,12 +250,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if rule.TranslatedAddress.Is6() {
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule) return m.router.AddDNATRule(rule)
} }
@@ -434,9 +258,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
@@ -445,82 +266,23 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var v4Prefixes, v6Prefixes []netip.Prefix return m.router.UpdateSet(set, prefixes)
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if localAddr.Is6() { return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if localAddr.Is6() { return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
const ( const (
@@ -556,10 +318,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !m.rawSupported {
return fmt.Errorf("raw table not available")
}
wgPortStr := fmt.Sprintf("%d", wgPort) wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort) proxyPortStr := fmt.Sprintf("%d", proxyPort)
@@ -617,16 +375,12 @@ func (m *Manager) initNoTrackChain() error {
return fmt.Errorf("add prerouting jump rule: %w", err) return fmt.Errorf("add prerouting jump rule: %w", err)
} }
m.rawSupported = true
return nil return nil
} }
func (m *Manager) cleanupNoTrackChain() error { func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw) exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil { if err != nil {
if !m.rawSupported {
return nil
}
return fmt.Errorf("check chain exists: %w", err) return fmt.Errorf("check chain exists: %w", err)
} }
if !exists { if !exists {
@@ -647,7 +401,6 @@ func (m *Manager) cleanupNoTrackChain() error {
return fmt.Errorf("clear and delete chain: %w", err) return fmt.Errorf("clear and delete chain: %w", err)
} }
m.rawSupported = false
return nil return nil
} }

View File

@@ -47,6 +47,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set") panic("AddressFunc is not set")
} }
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) { func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -36,7 +36,6 @@ const (
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE" chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR" chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
@@ -44,7 +43,6 @@ const (
jumpManglePre = "jump-mangle-pre" jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre" jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post" jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp" jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre" markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post" markManglePost = "mark-mangle-post"
@@ -54,10 +52,8 @@ const (
snatSuffix = "_snat" snatSuffix = "_snat"
fwdSuffix = "_fwd" fwdSuffix = "_fwd"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipv4TCPHeaderSize = 40 ipTCPHeaderMinSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
) )
type ruleInfo struct { type ruleInfo struct {
@@ -88,7 +84,6 @@ type router struct {
wgIface iFaceMapper wgIface iFaceMapper
legacyManagement bool legacyManagement bool
mtu uint16 mtu uint16
v6 bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
@@ -100,7 +95,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
mtu: mtu, mtu: mtu,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
@@ -190,11 +184,6 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil return ruleKey, nil
} }
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.ID()
@@ -398,18 +387,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
} else if ok {
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
}
for _, chainInfo := range []struct { for _, chainInfo := range []struct {
chain string chain string
table string table string
@@ -419,7 +396,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle}, {chainRTMSSCLAMP, tableMangle},
} { } {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
@@ -447,12 +423,6 @@ func (r *router) createContainers() error {
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle}, {chainRTMSSCLAMP, tableMangle},
} { } {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
@@ -559,12 +529,9 @@ func (r *router) addPostroutingRules() error {
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error { func (r *router) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize) mss := r.mtu - ipTCPHeaderMinSize
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain // Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{ jumpRule := []string{
@@ -749,13 +716,8 @@ func (r *router) updateState() {
currentState.Lock() currentState.Lock()
defer currentState.Unlock() defer currentState.Unlock()
if r.v6 { currentState.RouteRules = r.rules
currentState.RouteRules6 = r.rules currentState.RouteIPsetCounter = r.ipsetCounter
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
}
if err := r.stateManager.UpdateState(currentState); err != nil { if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
@@ -883,7 +845,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
} }
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
} }
delete(r.rules, ruleKey+fwdSuffix) delete(r.rules, ruleKey+fwdSuffix)
@@ -910,7 +872,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
rule = append(rule, destExp...) rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL { if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6))) rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...) rule = append(rule, applyPort("--dport", params.DPort)...)
} }
@@ -927,12 +889,11 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
} }
if network.IsSet() { if network.IsSet() {
name := r.ipsetName(network.Set.HashedName()) if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
return []string{"-m", "set", matchSet, name, direction}, nil return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
} }
if network.IsPrefix() { if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil return []string{flag, network.Prefix.String()}, nil
@@ -943,23 +904,27 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
} }
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error var merr *multierror.Error
for _, prefix := range prefixes { for _, prefix := range prefixes {
if err := r.addPrefixToIPSet(name, prefix); err != nil { // TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
} }
} }
if merr == nil { if merr == nil {
log.Debugf("updated set %s with prefixes %v", name, prefixes) log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists { if _, exists := r.rules[ruleID]; exists {
return nil return nil
@@ -967,12 +932,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
dnatRule := []string{ dnatRule := []string{
"-i", r.wgIface.Name(), "-i", r.wgIface.Name(),
"-p", strings.ToLower(protoForFamily(protocol, r.v6)), "-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(originalPort)), "--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(), "-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL", "-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT", "-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)), "--to-destination", ":" + strconv.Itoa(int(targetPort)),
} }
ruleInfo := ruleInfo{ ruleInfo := ruleInfo{
@@ -991,8 +956,8 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists { if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
@@ -1005,81 +970,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil return nil
} }
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {
if port == nil { if port == nil {
return nil return nil
@@ -1100,22 +990,10 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))} return []string{flag, strconv.Itoa(int(port.Values[0]))}
} }
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
// to avoid collisions since ipsets are global in the kernel.
func (r *router) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *router) createIPSet(name string) error { func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{ opts := ipset.CreateOptions{
Replace: true, Replace: true,
} }
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err) return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -9,7 +9,6 @@ type Rule struct {
mangleSpecs []string mangleSpecs []string
ip string ip string
chain string chain string
v6 bool
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@@ -4,16 +4,15 @@ import (
"fmt" "fmt"
"sync" "sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"` UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
@@ -24,6 +23,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct { type ShutdownState struct {
sync.Mutex sync.Mutex
@@ -34,12 +37,6 @@ type ShutdownState struct {
ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -70,28 +67,6 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore ipt.aclMgr.ipsetStore = s.ACLIPsetStore
} }
// Clean up v6 state even if the current run has no IPv6.
// The previous run may have left ip6tables rules behind.
if !ipt.hasIPv6() {
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
log.Warnf("failed to create v6 components for cleanup: %v", err)
}
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.router6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
}
}
if err := ipt.Close(nil); err != nil { if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err) return fmt.Errorf("reset iptables manager: %w", err)
} }

View File

@@ -1,7 +1,6 @@
package manager package manager
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -12,10 +11,6 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// ErrIPv6NotInitialized is returned when an IPv6 address is passed to a firewall
// method but the IPv6 firewall components were not initialized.
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s-%t"
@@ -169,16 +164,10 @@ type Manager interface {
UpdateSet(hash Set, prefixes []netip.Prefix) error UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule // RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication. // This prevents conntrack from interfering with WireGuard proxy communication.

View File

@@ -1,8 +1,6 @@
package manager package manager
import ( import (
"net/netip"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -12,10 +10,6 @@ type RouterPair struct {
Destination Network Destination Network
Masquerade bool Masquerade bool
Inverse bool Inverse bool
// Dynamic indicates the route is domain-based. NAT rules for dynamic
// routes are duplicated to the v6 table so that resolved AAAA records
// are masqueraded correctly.
Dynamic bool
} }
func GetInversePair(pair RouterPair) RouterPair { func GetInversePair(pair RouterPair) RouterPair {
@@ -26,17 +20,5 @@ func GetInversePair(pair RouterPair) RouterPair {
Destination: pair.Source, Destination: pair.Source,
Masquerade: pair.Masquerade, Masquerade: pair.Masquerade,
Inverse: true, Inverse: true,
Dynamic: pair.Dynamic,
} }
} }
// ToV6NatPair creates a v6 counterpart of a v4 NAT pair with `::/0` source
// and, for prefix destinations, `::/0` destination.
func ToV6NatPair(pair RouterPair) RouterPair {
v6 := pair
v6.Source = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if v6.Destination.IsPrefix() {
v6.Destination = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
return v6
}

View File

@@ -33,12 +33,15 @@ const (
const flushError = "flush: %w" const flushError = "flush: %w"
var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
af addrFamily
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
@@ -64,7 +67,6 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routingFwChainName: routingFwChainName, routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
@@ -143,7 +145,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
} }
if _, ok := ips[r.ip.String()]; ok { if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}}) err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
if err != nil { if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
} }
@@ -252,11 +254,11 @@ func (m *AclManager) addIOFiltering(
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.protoOffset, Offset: uint32(9),
Len: uint32(1), Len: uint32(1),
}) })
protoData, err := m.af.protoNum(proto) protoData, err := protoToInt(proto)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err) return nil, fmt.Errorf("convert protocol to number: %v", err)
} }
@@ -268,16 +270,19 @@ func (m *AclManager) addIOFiltering(
}) })
} }
rawIP := ipToBytes(ip, m.af) rawIP := ip.To4()
// check if rawIP contains zeroed IPv4 0.0.0.0 value // check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition // in that case not add IP match expression into the rule definition
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) { if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrOffset := uint32(12)
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.srcAddrOffset, Offset: addrOffset,
Len: m.af.addrLen, Len: 4,
}, },
) )
// add individual IP for match if no ipset defined // add individual IP for match if no ipset defined
@@ -582,7 +587,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ipToBytes(ip, m.af) rawIP := ip.To4()
if err != nil { if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err) return nil, fmt.Errorf("get set name: %v", err)
@@ -614,7 +619,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
Name: name, Name: name,
Table: table, Table: table,
Dynamic: true, Dynamic: true,
KeyType: m.af.setKeyType, KeyType: nftables.TypeIPAddr,
} }
if err := m.rConn.AddSet(ipset, nil); err != nil { if err := m.rConn.AddSet(ipset, nil); err != nil {
@@ -702,12 +707,15 @@ func ifname(n string) []byte {
return b return b
} }
func protoToInt(protocol firewall.Protocol) (uint8, error) {
// ipToBytes converts net.IP to the correct byte length for the address family. switch protocol {
func ipToBytes(ip net.IP, af addrFamily) []byte { case firewall.ProtocolTCP:
if af.addrLen == 4 { return unix.IPPROTO_TCP, nil
return ip.To4() case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
} }
return ip.To16()
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}

View File

@@ -1,81 +0,0 @@
package nftables
import (
"fmt"
"net"
"github.com/google/nftables"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
// afIPv4 defines IPv4 header layout and nftables types.
afIPv4 = addrFamily{
protoOffset: 9,
srcAddrOffset: 12,
dstAddrOffset: 16,
addrLen: net.IPv4len,
totalBits: 8 * net.IPv4len,
setKeyType: nftables.TypeIPAddr,
tableFamily: nftables.TableFamilyIPv4,
icmpProto: unix.IPPROTO_ICMP,
}
// afIPv6 defines IPv6 header layout and nftables types.
afIPv6 = addrFamily{
protoOffset: 6,
srcAddrOffset: 8,
dstAddrOffset: 24,
addrLen: net.IPv6len,
totalBits: 8 * net.IPv6len,
setKeyType: nftables.TypeIP6Addr,
tableFamily: nftables.TableFamilyIPv6,
icmpProto: unix.IPPROTO_ICMPV6,
}
)
// addrFamily holds protocol-specific constants for nftables expression building.
type addrFamily struct {
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
protoOffset uint32
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
srcAddrOffset uint32
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
dstAddrOffset uint32
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
addrLen uint32
// totalBits is the address size in bits (32 for v4, 128 for v6)
totalBits int
// setKeyType is the nftables set data type for addresses
setKeyType nftables.SetDatatype
// tableFamily is the nftables table family
tableFamily nftables.TableFamily
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
icmpProto uint8
}
// familyForAddr returns the address family for the given IP.
func familyForAddr(is4 bool) addrFamily {
if is4 {
return afIPv4
}
return afIPv6
}
// protoNum converts a firewall protocol to the IP protocol number,
// using the correct ICMP variant for the address family.
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return af.icmpProto, nil
case firewall.ProtocolALL:
return 0, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}

View File

@@ -1,76 +0,0 @@
//go:build linux
package nftables
import (
"os"
"sync/atomic"
"testing"
"time"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
// TestExternalChainMonitorRootIntegration verifies that adding a new chain
// in an external (non-netbird) filter table triggers the reconciler.
// Requires CAP_NET_ADMIN; skip otherwise.
func TestExternalChainMonitorRootIntegration(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("root required")
}
calls := make(chan struct{}, 8)
var count atomic.Int32
rec := &countingReconciler{calls: calls, count: &count}
m := newExternalChainMonitor(rec)
m.start()
t.Cleanup(m.stop)
// Give the netlink subscription a moment to register.
time.Sleep(200 * time.Millisecond)
conn := &nftables.Conn{}
table := conn.AddTable(&nftables.Table{
Name: "nbmon_integration_test",
Family: nftables.TableFamilyINet,
})
t.Cleanup(func() {
cleanup := &nftables.Conn{}
cleanup.DelTable(table)
_ = cleanup.Flush()
})
chain := conn.AddChain(&nftables.Chain{
Name: "filter_INPUT",
Table: table,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
})
_ = chain
require.NoError(t, conn.Flush(), "create external test chain")
select {
case <-calls:
// success
case <-time.After(3 * time.Second):
t.Fatalf("reconcile was not invoked after creating an external chain")
}
require.GreaterOrEqual(t, count.Load(), int32(1))
}
type countingReconciler struct {
calls chan struct{}
count *atomic.Int32
}
func (c *countingReconciler) reconcileExternalChains() error {
c.count.Add(1)
select {
case c.calls <- struct{}{}:
default:
}
return nil
}

View File

@@ -1,199 +0,0 @@
package nftables
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
)
const (
externalMonitorReconcileDelay = 500 * time.Millisecond
externalMonitorInitInterval = 5 * time.Second
externalMonitorMaxInterval = 5 * time.Minute
externalMonitorRandomization = 0.5
)
// externalChainReconciler re-applies passthrough accept rules to external
// nftables chains. Implementations must be safe to call from the monitor
// goroutine; the Manager locks its mutex internally.
type externalChainReconciler interface {
reconcileExternalChains() error
}
// externalChainMonitor watches nftables netlink events and triggers a
// reconcile when a new table or chain appears (e.g. after
// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff
// reconnect.
type externalChainMonitor struct {
reconciler externalChainReconciler
mu sync.Mutex
cancel context.CancelFunc
done chan struct{}
}
func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor {
return &externalChainMonitor{reconciler: r}
}
func (m *externalChainMonitor) start() {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
m.done = make(chan struct{})
go m.run(ctx)
}
func (m *externalChainMonitor) stop() {
m.mu.Lock()
cancel := m.cancel
done := m.done
m.cancel = nil
m.done = nil
m.mu.Unlock()
if cancel == nil {
return
}
cancel()
<-done
}
func (m *externalChainMonitor) run(ctx context.Context) {
defer close(m.done)
bo := &backoff.ExponentialBackOff{
InitialInterval: externalMonitorInitInterval,
RandomizationFactor: externalMonitorRandomization,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: externalMonitorMaxInterval,
MaxElapsedTime: 0,
Clock: backoff.SystemClock,
}
bo.Reset()
for ctx.Err() == nil {
err := m.watch(ctx)
if ctx.Err() != nil {
return
}
delay := bo.NextBackOff()
log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay)
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
}
}
func (m *externalChainMonitor) watch(ctx context.Context) error {
events, closeMon, err := m.subscribe()
if err != nil {
return err
}
defer closeMon()
debounce := time.NewTimer(time.Hour)
if !debounce.Stop() {
<-debounce.C
}
defer debounce.Stop()
pending := false
for {
select {
case <-ctx.Done():
return nil
case <-debounce.C:
pending = false
m.reconcile()
case ev, ok := <-events:
if !ok {
return errors.New("monitor channel closed")
}
if ev.Error != nil {
return fmt.Errorf("monitor event: %w", ev.Error)
}
if !isRelevantMonitorEvent(ev) {
continue
}
resetDebounce(debounce, pending)
pending = true
}
}
}
func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) {
conn := &nftables.Conn{}
mon := nftables.NewMonitor(
nftables.WithMonitorAction(nftables.MonitorActionNew),
nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables),
)
events, err := conn.AddMonitor(mon)
if err != nil {
return nil, nil, fmt.Errorf("add netlink monitor: %w", err)
}
return events, func() { _ = mon.Close() }, nil
}
// resetDebounce reschedules a pending debounce timer without leaking a stale
// fire on its channel. pending must reflect whether the timer is armed.
func resetDebounce(t *time.Timer, pending bool) {
if pending && !t.Stop() {
select {
case <-t.C:
default:
}
}
t.Reset(externalMonitorReconcileDelay)
}
func (m *externalChainMonitor) reconcile() {
if err := m.reconciler.reconcileExternalChains(); err != nil {
log.Warnf("reconcile external chain rules: %v", err)
}
}
// isRelevantMonitorEvent returns true for table/chain creation events on
// families we care about. The reconciler filters to actual external filter
// chains.
func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool {
switch ev.Type {
case nftables.MonitorEventTypeNewChain:
chain, ok := ev.Data.(*nftables.Chain)
if !ok || chain == nil || chain.Table == nil {
return false
}
return isMonitoredFamily(chain.Table.Family)
case nftables.MonitorEventTypeNewTable:
table, ok := ev.Data.(*nftables.Table)
if !ok || table == nil {
return false
}
return isMonitoredFamily(table.Family)
}
return false
}
func isMonitoredFamily(family nftables.TableFamily) bool {
switch family {
case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet:
return true
}
return false
}

View File

@@ -1,137 +0,0 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/assert"
)
func TestIsMonitoredFamily(t *testing.T) {
tests := []struct {
family nftables.TableFamily
want bool
}{
{nftables.TableFamilyIPv4, true},
{nftables.TableFamilyIPv6, true},
{nftables.TableFamilyINet, true},
{nftables.TableFamilyARP, false},
{nftables.TableFamilyBridge, false},
{nftables.TableFamilyNetdev, false},
{nftables.TableFamilyUnspecified, false},
}
for _, tc := range tests {
assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family)
}
}
func TestIsRelevantMonitorEvent(t *testing.T) {
inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet}
ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP}
tests := []struct {
name string
ev *nftables.MonitorEvent
want bool
}{
{
name: "new chain in inet firewalld",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
},
want: true,
},
{
name: "new chain in ip filter",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "INPUT", Table: ipTable},
},
want: true,
},
{
name: "new chain in unwatched arp family",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "x", Table: arpTable},
},
want: false,
},
{
name: "new table inet",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewTable,
Data: inetTable,
},
want: true,
},
{
name: "del chain (we only act on new)",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeDelChain,
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
},
want: false,
},
{
name: "chain with nil table",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "x"},
},
want: false,
},
{
name: "nil data",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: (*nftables.Chain)(nil),
},
want: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev))
})
}
}
// fakeReconciler records reconcile invocations for debounce tests.
type fakeReconciler struct {
calls chan struct{}
}
func (f *fakeReconciler) reconcileExternalChains() error {
f.calls <- struct{}{}
return nil
}
func TestExternalChainMonitorStopWithoutStart(t *testing.T) {
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
// Must not panic or block.
m.stop()
}
func TestExternalChainMonitorDoubleStart(t *testing.T) {
// start() twice should be a no-op; stop() cleans up once.
// We avoid exercising the netlink watch loop here because it needs root.
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
// Replace run with a stub that just waits for cancel, so start() stays
// deterministic without opening a netlink socket.
origDone := make(chan struct{})
m.done = origDone
m.cancel = func() { close(origDone) }
// Second start should be a no-op (cancel already set).
m.start()
assert.NotNil(t, m.cancel)
m.stop()
assert.Nil(t, m.cancel)
assert.Nil(t, m.done)
}

View File

@@ -11,12 +11,9 @@ import (
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -43,6 +40,7 @@ func getTableName() string {
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address
IsUserspaceBind() bool
} }
// Manager of iptables firewall // Manager of iptables firewall
@@ -51,17 +49,10 @@ type Manager struct {
rConn *nftables.Conn rConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
router *router router *router
aclManager *AclManager aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
notrackOutputChain *nftables.Chain notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain notrackPreroutingChain *nftables.Chain
extMonitor *externalChainMonitor
} }
// Create nftables firewall manager // Create nftables firewall manager
@@ -71,8 +62,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
wgIface: wgIface, wgIface: wgIface,
} }
tableName := getTableName() workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error var err error
m.router, err = newRouter(workTable, wgIface, mtu) m.router, err = newRouter(workTable, wgIface, mtu)
@@ -85,170 +75,54 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
m.extMonitor = newExternalChainMonitor(m)
return m, nil return m, nil
} }
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.router6, err = newRouter(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.router6 != nil
}
func (m *Manager) initIPv6() error {
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
if err != nil {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
}
return nil
}
// Init nftables firewall manager // Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error { func (m *Manager) Init(stateManager *statemanager.Manager) error {
if err := m.initFirewall(); err != nil {
return err
}
m.persistState(stateManager)
// Start after initFirewall has installed the baseline external-chain
// accept rules. start() is idempotent across Init/Close/Init cycles.
m.extMonitor.start()
return nil
}
// reconcileExternalChains re-applies passthrough accept rules to external
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
// tables or chains appear (e.g. after firewalld reloads).
func (m *Manager) reconcileExternalChains() error {
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if m.router != nil {
if err := m.router.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
}
}
if m.hasIPv6() {
if err := m.router6.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (m *Manager) initFirewall() (err error) {
workTable, err := m.createWorkTable() workTable, err := m.createWorkTable()
if err != nil { if err != nil {
return fmt.Errorf("create work table: %w", err) return fmt.Errorf("create work table: %w", err)
} }
defer func() {
if err != nil {
m.rollbackInit()
}
}()
if err := m.router.init(workTable); err != nil { if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err) return fmt.Errorf("router init: %w", err)
} }
if err := m.aclManager.init(workTable); err != nil { if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err) return fmt.Errorf("acl manager init: %w", err)
} }
if m.hasIPv6() {
if err := m.initIPv6(); err != nil {
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
}
}
if err := m.initNoTrackChains(workTable); err != nil { if err := m.initNoTrackChains(workTable); err != nil {
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) return fmt.Errorf("init notrack chains: %w", err)
} }
return nil
}
// persistState saves the current interface state for potential recreation on restart.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
func (m *Manager) persistState(stateManager *statemanager.Manager) {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(), WGAddress: m.wgIface.Address(),
MTU: m.router.mtu, UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
}, },
}); err != nil { }); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
// persist early
go func() { go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
}() }()
}
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through. return nil
func (m *Manager) rollbackInit() {
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
log.Warnf("cleanup tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
log.Warnf("flush: %v", err)
}
} }
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
@@ -267,14 +141,12 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if ip.To4() != nil { rawIP := ip.To4()
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) if rawIP == nil {
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
if !m.hasIPv6() { return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@@ -288,11 +160,8 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
if !m.hasIPv6() { return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -303,66 +172,15 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule) return m.aclManager.DeletePeerRule(rule)
} }
func isIPv6Rule(rule firewall.Rule) bool { // DeleteRouteRule deletes a routing rule
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
// router; the cached maps are normally authoritative, so the kernel is only
// consulted when neither map knows about the rule.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
id := rule.ID() return m.router.DeleteRouteRule(rule)
r, err := m.routerForRuleID(id, (*router).hasRule)
if err != nil {
return err
}
return r.DeleteRouteRule(rule)
}
// routerForRuleID picks the router holding the rule with the given id, using
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
// from the kernel once and re-checks before falling back to the v4 router.
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
if has(m.router, id) {
return m.router, nil
}
if m.hasIPv6() && has(m.router6, id) {
return m.router6, nil
}
if !m.hasIPv6() {
return m.router, nil
}
if err := m.router.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v4 rules: %w", err)
}
if err := m.router6.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v6 rules: %w", err)
}
if has(m.router6, id) && !has(m.router, id) {
return m.router6, nil
}
return m.router, nil
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
@@ -377,136 +195,62 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { return m.router.AddNatRule(pair)
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables since resolved IPs can be
// either v4 or v6. This covers both DomainSet (modern) and the legacy
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
// On v6 failure we keep the v4 NAT rule rather than rolling back: half
// connectivity is better than none, and RemoveNatRule is content-keyed
// so the eventual cleanup still works.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { return m.router.RemoveNatRule(pair)
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic. // AllowNetbird allows netbird interface traffic
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
//
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
// which doesn't override DROP rules in external tables (e.g. firewalld).
// Should add passthrough rules to external chains (like the native mode router's
// addExternalChainsRules does) for both the netbird table family and inet tables.
// The netbird table itself is fine (routing chains already exist there), but
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.aclManager.createDefaultAllowRules(); err != nil { if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err) return fmt.Errorf("create default allow rules: %w", err)
} }
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err) return fmt.Errorf("flush allow input netbird rules: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }
// SetLegacyManagement sets the route manager to use legacy management // SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { return firewall.SetLegacyManagement(m.router, isLegacy)
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
} }
// Close closes the firewall manager // Close closes the firewall manager
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.extMonitor.stop()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var merr *multierror.Error
if err := m.router.Reset(); err != nil { if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err)) return fmt.Errorf("reset router: %v", err)
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
}
} }
if err := m.cleanupNetbirdTables(); err != nil { if err := m.cleanupNetbirdTables(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err)) return fmt.Errorf("cleanup netbird tables: %v", err)
} }
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err)) return fmt.Errorf(flushError, err)
} }
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err)) return fmt.Errorf("delete state: %v", err)
} }
return nberrors.FormatErrorOrNil(merr) return nil
} }
func (m *Manager) cleanupNetbirdTables() error { func (m *Manager) cleanupNetbirdTables() error {
@@ -555,12 +299,6 @@ func (m *Manager) Flush() error {
return err return err
} }
if m.hasIPv6() {
if err := m.aclManager6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
if err := m.refreshNoTrackChains(); err != nil { if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err) log.Errorf("failed to refresh notrack chains: %v", err)
} }
@@ -573,12 +311,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if rule.TranslatedAddress.Is6() {
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule) return m.router.AddDNATRule(rule)
} }
@@ -587,11 +319,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule) return m.router.DeleteDNATRule(rule)
if err != nil {
return err
}
return r.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes // UpdateSet updates the set with the given prefixes
@@ -599,82 +327,23 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var v4Prefixes, v6Prefixes []netip.Prefix return m.router.UpdateSet(set, prefixes)
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if localAddr.Is6() { return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if localAddr.Is6() { return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
const ( const (
@@ -848,11 +517,7 @@ func (m *Manager) refreshNoTrackChains() error {
} }
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
return m.createWorkTableFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
}
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
@@ -864,7 +529,7 @@ func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family}) table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }

View File

@@ -52,6 +52,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set") panic("AddressFunc is not set")
} }
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
@@ -383,138 +385,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
err = manager.AddNatRule(pair) err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule") require.NoError(t, err, "failed to add NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "failed to add DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
})
stdout, stderr = runIptablesSave(t) stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
} }
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
if _, err := exec.LookPath(bin); err != nil {
t.Skipf("%s not available on this system: %v", bin, err)
}
}
// Seed ip6 tables in the nft backend. Docker may not create them.
seedIp6tables(t)
ifaceMockV6 := &iFaceMock{
NameFunc: func() string { return "wt-test" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
require.NoError(t, err, "create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil), "close manager")
stdout, stderr := runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "add v6 route filtering rule")
err = manager.AddNatRule(fw.RouterPair{
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
Masquerade: true,
})
require.NoError(t, err, "add v6 NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("fd00::2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "add v6 DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
})
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
stdout, stderr = runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
}
func seedIp6tables(t *testing.T) {
t.Helper()
for _, tc := range []struct{ table, chain string }{
{"filter", "FORWARD"},
{"nat", "POSTROUTING"},
{"mangle", "FORWARD"},
} {
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
}
}
func runIp6tablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("ip6tables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "ip6tables-save failed")
return stdout.String(), stderr.String()
}
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
for _, msg := range []string{
"Table `nat' is incompatible",
"Table `mangle' is incompatible",
"Table `filter' is incompatible",
} {
require.NotContains(t, stdout, msg,
"ip6tables-save stdout reports incompatibility: %s", stdout)
require.NotContains(t, stderr, msg,
"ip6tables-save stderr reports incompatibility: %s", stderr)
}
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this system") t.Skip("nftables not supported on this system")

File diff suppressed because it is too large Load Diff

View File

@@ -18,7 +18,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
) )
const ( const (
@@ -90,9 +89,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
// Build CIDR matching expressions // Build CIDR matching expressions
testRouter := &router{af: afIPv4} sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true) destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order // Combine all expressions in the correct order
// nolint:gocritic // nolint:gocritic
@@ -509,136 +507,6 @@ func TestNftablesCreateIpSet(t *testing.T) {
} }
} }
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTableIPv6()
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IPv6",
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
},
{
name: "Multiple IPv6 Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::/48"),
netip.MustParsePrefix("fe80::/10"),
},
},
{
name: "Overlapping IPv6",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("fd00::1/128"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
},
},
{
name: "Mixed prefix lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::1/128"),
netip.MustParsePrefix("fd00:abcd::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
require.NoError(t, err, "Failed to create IPv6 set")
require.NotNil(t, set)
assert.Equal(t, setName, set.Name)
assert.True(t, set.Interval)
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd && len(elem.Key) == 16 {
ip := netip.AddrFrom16([16]byte(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
r.conn.DelSet(set)
require.NoError(t, r.conn.Flush())
})
}
}
func createWorkTableIPv6() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
err = sConn.Flush()
return table, err
}
func deleteWorkTableIPv6() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
_ = sConn.Flush()
}
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper() t.Helper()
@@ -758,7 +626,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool var metaFound, cmpFound bool
expectedProto, _ := afIPv4.protoNum(proto) expectedProto, _ := protoToInt(proto)
for _, e := range exprs { for _, e := range exprs {
switch ex := e.(type) { switch ex := e.(type) {
case *expr.Meta: case *expr.Meta:
@@ -851,189 +719,3 @@ func deleteWorkTable() {
} }
} }
} }
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}
func TestCalculateLastIP(t *testing.T) {
tests := []struct {
prefix string
want string
}{
{"10.0.0.0/24", "10.0.0.255"},
{"10.0.0.0/32", "10.0.0.0"},
{"0.0.0.0/0", "255.255.255.255"},
{"192.168.1.0/28", "192.168.1.15"},
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
{"fd00::/128", "fd00::"},
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
}
for _, tt := range tests {
t.Run(tt.prefix, func(t *testing.T) {
prefix := netip.MustParsePrefix(tt.prefix)
got := calculateLastIP(prefix)
assert.Equal(t, tt.want, got.String())
})
}
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &router{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),
}
elements := r.convertPrefixesToSet(prefixes)
// Each prefix produces 2 elements (start + end)
require.Len(t, elements, 4)
// fd00::/64 start
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
assert.False(t, elements[0].IntervalEnd)
// fd00::/64 end (fd00:0:0:1::, one past the last)
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
assert.True(t, elements[1].IntervalEnd)
// 2001:db8::1/128 start
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
assert.False(t, elements[2].IntervalEnd)
// 2001:db8::1/128 end (2001:db8::2)
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
assert.True(t, elements[3].IntervalEnd)
}

View File

@@ -8,9 +8,10 @@ import (
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"` UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
@@ -21,6 +22,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct { type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"` InterfaceState *InterfaceState `json:"interface_state,omitempty"`
} }

View File

@@ -3,9 +3,12 @@
package uspfilter package uspfilter
import ( import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -14,14 +17,37 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.resetState() m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager) return m.nativeFirewall.Close(stateManager)
} }
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil return nil
} }
@@ -30,8 +56,5 @@ func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird() return m.nativeFirewall.AllowNetbird()
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -1,14 +1,15 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -25,26 +26,47 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.resetState() m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {
return nil return nil
} }
var merr *multierror.Error if !isFirewallRuleActive(firewallRuleName) {
if isFirewallRuleActive(firewallRuleName) { return nil
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
}
} }
if isFirewallRuleActive(firewallRuleName + "-v6") { if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil { return fmt.Errorf("couldn't remove windows firewall: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
}
} }
return nberrors.FormatErrorOrNil(merr) return nil
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -53,33 +75,17 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
if !isFirewallRuleActive(firewallRuleName) { if isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName, return nil
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
); err != nil {
return err
}
} }
return manageFirewallRule(firewallRuleName,
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") { addRule,
if err := manageFirewallRule(firewallRuleName+"-v6", "dir=in",
addRule, "enable=yes",
"dir=in", "action=allow",
"enable=yes", "profile=any",
"action=allow", "localip="+m.wgIface.Address().IP.String(),
"profile=any", )
"localip="+v6.String(),
); err != nil {
return err
}
}
return nil
} }
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {

View File

@@ -1,37 +0,0 @@
package common
import (
"net/netip"
"sync/atomic"
)
// PacketHook stores a registered hook for a specific IP:port.
type PacketHook struct {
IP netip.Addr
Port uint16
Fn func([]byte) bool
}
// HookMatches checks if a packet's destination matches the hook and invokes it.
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.IP == dstIP && h.Port == dport {
return h.Fn(packetData)
}
return false
}
// SetHook atomically stores a hook, handling nil removal.
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
if hook == nil {
ptr.Store(nil)
return
}
ptr.Store(&PacketHook{
IP: ip,
Port: dPort,
Fn: hook,
})
}

View File

@@ -9,7 +9,6 @@ import (
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() wgaddr.Address Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device

View File

@@ -1,9 +1,8 @@
package conntrack package conntrack
import ( import (
"net" "fmt"
"net/netip" "net/netip"
"strconv"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -65,7 +64,5 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) + return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
" → " +
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
} }

View File

@@ -13,54 +13,6 @@ import (
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestConnKey_String(t *testing.T) {
tests := []struct {
name string
key ConnKey
expect string
}{
{
name: "IPv4",
key: ConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
SrcPort: 12345,
DstPort: 80,
},
expect: "192.168.1.1:12345 → 10.0.0.1:80",
},
{
name: "IPv6",
key: ConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
SrcPort: 54321,
DstPort: 443,
},
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
},
{
name: "IPv4-mapped IPv6 unmaps",
key: ConnKey{
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
SrcPort: 1000,
DstPort: 2000,
},
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strconv"
"sync" "sync"
"time" "time"
@@ -22,14 +21,9 @@ const (
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info. // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// IPv4: 20-byte header + 8-byte transport = 28 bytes. // which includes the IP header (20 bytes) and transport header (8 bytes)
// IPv6: 40-byte header + 8-byte transport = 48 bytes. MaxICMPPayloadLength = 28
MaxICMPPayloadLength = 48
// minICMPPayloadIPv4 is the minimum embedded packet length for IPv4 ICMP errors.
minICMPPayloadIPv4 = 28
// minICMPPayloadIPv6 is the minimum embedded packet length for IPv6 ICMP errors.
minICMPPayloadIPv6 = 48
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@@ -71,7 +65,7 @@ type ICMPInfo struct {
// String implements fmt.Stringer for lazy evaluation in log messages // String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string { func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= minICMPPayloadIPv4 { if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" { if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo) return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
} }
@@ -80,72 +74,42 @@ func (info ICMPInfo) String() string {
return info.TypeCode.String() return info.TypeCode.String()
} }
// isErrorMessage returns true if this ICMP type carries original packet info. // isErrorMessage returns true if this ICMP type carries original packet info
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
// kept as a literal.
func (info ICMPInfo) isErrorMessage() bool { func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type() typ := info.TypeCode.Type()
// ICMPv4 error types return typ == 3 || // Destination Unreachable
if typ == layers.ICMPv4TypeDestinationUnreachable || typ == 5 || // Redirect
typ == layers.ICMPv4TypeRedirect || typ == 11 || // Time Exceeded
typ == layers.ICMPv4TypeTimeExceeded || typ == 12 // Parameter Problem
typ == layers.ICMPv4TypeParameterProblem {
return true
}
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
if typ == layers.ICMPv6TypeDestinationUnreachable ||
typ == layers.ICMPv6TypePacketTooBig ||
typ == layers.ICMPv6TypeParameterProblem {
return true
}
return false
} }
// parseOriginalPacket extracts info about the original packet from ICMP payload // parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string { func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen == 0 { if info.PayloadLen < MaxICMPPayloadLength {
return "" return ""
} }
version := (info.PayloadData[0] >> 4) & 0xF // TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
var protocol uint8
var srcIP, dstIP net.IP
var transportData []byte
switch version {
case 4:
if info.PayloadLen < minICMPPayloadIPv4 {
return ""
}
protocol = info.PayloadData[9]
srcIP = net.IP(info.PayloadData[12:16])
dstIP = net.IP(info.PayloadData[16:20])
transportData = info.PayloadData[20:]
case 6:
if info.PayloadLen < minICMPPayloadIPv6 {
return ""
}
// Next Header field in IPv6 header
protocol = info.PayloadData[6]
srcIP = net.IP(info.PayloadData[8:24])
dstIP = net.IP(info.PayloadData[24:40])
transportData = info.PayloadData[40:]
default:
return "" return ""
} }
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) { switch nftypes.Protocol(protocol) {
case nftypes.TCP: case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))) return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP: case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))) return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP: case nftypes.ICMP:
icmpType := transportData[0] icmpType := transportData[0]
@@ -283,10 +247,9 @@ func (t *ICMPTracker) track(
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request. // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) { if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false return false
} }
@@ -338,13 +301,6 @@ func (t *ICMPTracker) cleanup() {
} }
} }
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
if ip.Is6() {
return nftypes.ICMPv6
}
return nftypes.ICMP
}
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.tickerCancel() t.tickerCancel()
@@ -360,7 +316,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
Type: typ, Type: typ,
RuleID: ruleID, RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: icmpProtocolForAddr(conn.SourceIP), Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP, SourceIP: conn.SourceIP,
DestIP: conn.DestIP, DestIP: conn.DestIP,
ICMPType: conn.ICMPType, ICMPType: conn.ICMPType,
@@ -378,7 +334,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
Type: nftypes.TypeStart, Type: nftypes.TypeStart,
RuleID: ruleID, RuleID: ruleID,
Direction: direction, Direction: direction,
Protocol: icmpProtocolForAddr(srcIP), Protocol: nftypes.ICMP,
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
ICMPType: typ, ICMPType: typ,

View File

@@ -5,42 +5,6 @@ import (
"testing" "testing"
) )
func TestICMPConnKey_String(t *testing.T) {
tests := []struct {
name string
key ICMPConnKey
expect string
}{
{
name: "IPv4",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
ID: 1234,
},
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
},
{
name: "IPv6",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
ID: 5678,
},
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)

View File

@@ -115,17 +115,6 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load() return t.tombstone.Load()
} }
// IsSupersededBy returns true if this connection should be replaced by a new one
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
// connections are superseded by a pure SYN (a new connection attempt for the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
if t.tombstone.Load() {
return true
}
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion // SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() { func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true) t.tombstone.Store(true)
@@ -180,7 +169,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if exists && !conn.IsSupersededBy(flags) { if exists {
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true return key, uint16(conn.DNATOrigPort.Load()), true
} }
@@ -252,7 +241,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists || conn.IsSupersededBy(flags) { if !exists || conn.IsTombstone() {
return false return false
} }

View File

@@ -485,261 +485,6 @@ func TestTCPAbnormalSequences(t *testing.T) {
}) })
} }
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) { func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing // Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond shortTimeout := 100 * time.Millisecond

View File

@@ -1,7 +1,6 @@
package uspfilter package uspfilter
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@@ -13,22 +12,18 @@ import (
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -36,10 +31,8 @@ import (
const ( const (
layerTypeAll = 255 layerTypeAll = 255
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipv4TCPHeaderMinSize = 40 ipTCPHeaderMinSize = 40
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
ipv6TCPHeaderMinSize = 60
) )
// serviceKey represents a protocol/port combination for netstack service registry // serviceKey represents a protocol/port combination for netstack service registry
@@ -96,7 +89,6 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -118,15 +110,14 @@ type Manager struct {
localipmanager *localIPManager localipmanager *localIPManager
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder atomic.Pointer[forwarder.Forwarder] forwarder atomic.Pointer[forwarder.Forwarder]
pendingCapture atomic.Pointer[forwarder.PacketCapture] logger *nblog.Logger
logger *nblog.Logger flowLogger nftypes.FlowLogger
flowLogger nftypes.FlowLogger
blockRules []firewall.Rule blockRule firewall.Rule
// Internal 1:1 DNAT // Internal 1:1 DNAT
dnatEnabled atomic.Bool dnatEnabled atomic.Bool
@@ -141,14 +132,9 @@ type Manager struct {
netstackServices map[serviceKey]struct{} netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex netstackServiceMutex sync.RWMutex
mtu uint16 mtu uint16
mssClampValueIPv4 uint16 mssClampValue uint16
mssClampValueIPv6 uint16 mssClampEnabled bool
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[common.PacketHook]
tcpHookOut atomic.Pointer[common.PacketHook]
} }
// decoder for packages // decoder for packages
@@ -161,28 +147,11 @@ type decoder struct {
icmp4 layers.ICMPv4 icmp4 layers.ICMPv4
icmp6 layers.ICMPv6 icmp6 layers.ICMPv6
decoded []gopacket.LayerType decoded []gopacket.LayerType
parser4 *gopacket.DecodingLayerParser parser *gopacket.DecodingLayerParser
parser6 *gopacket.DecodingLayerParser
dnatOrigPort uint16 dnatOrigPort uint16
} }
// decodePacket decodes packet data using the appropriate parser based on IP version.
func (d *decoder) decodePacket(data []byte) error {
if len(data) == 0 {
return errors.New("empty packet")
}
version := data[0] >> 4
switch version {
case 4:
return d.parser4.DecodeLayers(data, &d.decoded)
case 6:
return d.parser6.DecodeLayers(data, &d.decoded)
default:
return fmt.Errorf("unknown IP version %d", version)
}
}
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu) return create(iface, nil, disableServerRoutes, flowLogger, mtu)
@@ -240,17 +209,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser4 = gopacket.NewDecodingLayerParser( d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser4.IgnoreUnsupported = true d.parser.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d return d
}, },
}, },
@@ -266,7 +229,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger, flowLogger: flowLogger,
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr), dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{}, portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}), netstackServices: make(map[serviceKey]struct{}),
@@ -276,12 +238,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if !disableMSSClamping { if !disableMSSClamping {
m.mssClampEnabled = true m.mssClampEnabled = true
if mtu > ipv4TCPHeaderMinSize { m.mssClampValue = mtu - ipTCPHeaderMinSize
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
}
if mtu > ipv6TCPHeaderMinSize {
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
}
} }
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
@@ -304,25 +261,13 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return m, nil return m, nil
} }
// blockInvalidRouted installs drop rules for traffic to the wg overlay that func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
// arrives via the routing path. v4 and v6 are independent: a v6 install
// failure leaves v4 protection in place (and vice versa) so the returned
// slice always contains whatever was successfully installed, even on error.
// Callers must persist the slice so DisableRouting can clean partial state.
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
wgPrefix := iface.Address().Network wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)} rule, err := m.addRouteFiltering(
v6Net := iface.Address().IPv6Net
if v6Net.IsValid() {
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
}
var rules []firewall.Rule
v4Rule, err := m.addRouteFiltering(
nil, nil,
sources, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewall.Network{Prefix: wgPrefix}, firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL, firewall.ProtocolALL,
nil, nil,
@@ -330,30 +275,12 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
firewall.ActionDrop, firewall.ActionDrop,
) )
if err != nil { if err != nil {
return rules, fmt.Errorf("block wg v4 net: %w", err) return nil, fmt.Errorf("block wg nte : %w", err)
}
rules = append(rules, v4Rule)
if v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
v6Rule, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: v6Net},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
)
if err != nil {
return rules, fmt.Errorf("block wg v6 net: %w", err)
}
rules = append(rules, v6Rule)
} }
// TODO: Block networks that we're a client of // TODO: Block networks that we're a client of
return rules, nil return rule, nil
} }
func (m *Manager) determineRouting() error { func (m *Manager) determineRouting() error {
@@ -414,19 +341,6 @@ func (m *Manager) determineRouting() error {
return nil return nil
} }
// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
if pc == nil {
m.pendingCapture.Store(nil)
} else {
m.pendingCapture.Store(&pc)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.SetCapture(pc)
}
}
// initForwarder initializes the forwarder, it disables routing on errors // initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error { func (m *Manager) initForwarder() error {
if m.forwarder.Load() != nil { if m.forwarder.Load() != nil {
@@ -448,11 +362,6 @@ func (m *Manager) initForwarder() error {
m.forwarder.Store(forwarder) m.forwarder.Store(forwarder)
// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
if pc := m.pendingCapture.Load(); pc != nil {
forwarder.SetCapture(*pc)
}
log.Debug("forwarder initialized") log.Debug("forwarder initialized")
return nil return nil
@@ -571,19 +480,15 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleID := uuid.New().String()
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
id: string(ruleKey), id: ruleID,
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, dstSet: destination.Set,
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)), protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
@@ -594,7 +499,6 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule) m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort() m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil return &rule, nil
} }
@@ -611,20 +515,15 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
ruleKey := nbid.RuleID(rule.ID()) ruleID := rule.ID()
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool { idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey) return r.id == ruleID
}) })
if idx < 0 { if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey) return fmt.Errorf("route rule not found: %s", ruleID)
} }
m.routeRules = slices.Delete(m.routeRules, idx, idx+1) m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil return nil
} }
@@ -671,43 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
clear(m.outgoingRules)
clear(m.incomingDenyRules)
clear(m.incomingRules)
clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.SetCapture(nil)
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. // SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
@@ -738,7 +600,11 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
} }
destinations := matches[0].destinations destinations := matches[0].destinations
destinations = append(destinations, prefixes...) for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int { slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr()) cmp := a.Addr().Compare(b.Addr())
@@ -777,7 +643,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
if err := d.decodePacket(packetData); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
return false return false
} }
@@ -797,9 +663,6 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true return true
} }
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs. // Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping. // SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled { if m.mssClampEnabled {
@@ -861,32 +724,12 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false return false
} }
var mssClampValue uint16
var ipHeaderSize int
switch d.decoded[0] {
case layers.LayerTypeIPv4:
mssClampValue = m.mssClampValueIPv4
ipHeaderSize = int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
case layers.LayerTypeIPv6:
mssClampValue = m.mssClampValueIPv6
ipHeaderSize = 40
default:
return false
}
if mssClampValue == 0 {
return false
}
mssOptionIndex := -1 mssOptionIndex := -1
var currentMSS uint16 var currentMSS uint16
for i, opt := range d.tcp.Options { for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData) currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > mssClampValue { if currentMSS > m.mssClampValue {
mssOptionIndex = i mssOptionIndex = i
break break
} }
@@ -897,15 +740,20 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false return false
} }
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) { ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false return false
} }
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true return true
} }
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool { func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20 tcpOptionsStart := tcpHeaderStart + 20
@@ -920,7 +768,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex
} }
mssValueOffset := optOffset + 2 mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue) binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true return true
@@ -930,32 +778,18 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
tcpLayer := packetData[tcpHeaderStart:] tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart tcpLength := len(packetData) - tcpHeaderStart
// Zero out existing checksum
tcpLayer[16] = 0 tcpLayer[16] = 0
tcpLayer[17] = 0 tcpLayer[17] = 0
// Build pseudo-header checksum based on IP version
var pseudoSum uint32 var pseudoSum uint32
switch d.decoded[0] { pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
case layers.LayerTypeIPv4: pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) pseudoSum += uint32(tcpLength)
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
case layers.LayerTypeIPv6:
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
}
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
}
pseudoSum += uint32(tcpLength)
pseudoSum += uint32(layers.IPProtocolTCP)
}
sum := pseudoSum var sum = pseudoSum
for i := 0; i < tcpLength-1; i += 2 { for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
} }
@@ -993,9 +827,6 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
} }
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
} }
} }
@@ -1009,20 +840,44 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
} }
d.dnatOrigPort = 0 d.dnatOrigPort = 0
} }
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData) m.mutex.RLock()
} defer m.mutex.RUnlock()
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { // Check specific destination IP first
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData) if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
} }
// filterInbound implements filtering logic for incoming packets. // filterInbound implements filtering logic for incoming packets.
@@ -1044,19 +899,15 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder // TODO: pass fragments of routed packets to forwarder
if fragment { if fragment {
if d.decoded[0] == layers.LayerTypeIPv4 { m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
return false return false
} }
// TODO: optimize port DNAT by caching matched rules in conntrack // TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information // Re-decode after port DNAT translation to update port information
if err := d.decodePacket(packetData); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true return true
} }
@@ -1065,7 +916,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
if translated := m.translateInboundReverse(packetData, d); translated { if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses // Re-decode after translation to get original addresses
if err := d.decodePacket(packetData); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true return true
} }
@@ -1197,48 +1048,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true return true
} }
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
// both families uniformly. The echo ID is in the first two payload bytes.
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
if len(d.icmp6.Payload) >= 2 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
}
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
switch d.icmp6.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
case layers.ICMPv6TypeEchoReply:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
default:
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
}
return id, tc
}
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
// ICMP rules since management sends a single ICMP rule for both families.
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
if ruleLayer == packetLayer {
return true
}
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
return true
}
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
return true
}
return false
}
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
if p.Addr().Is6() {
return layers.LayerTypeIPv6
}
return layers.LayerTypeIPv4
}
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto { switch proto {
case firewall.ProtocolTCP: case firewall.ProtocolTCP:
@@ -1262,10 +1071,8 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
return nftypes.TCP return nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return nftypes.UDP return nftypes.UDP
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return nftypes.ICMP return nftypes.ICMP
case layers.LayerTypeICMPv6:
return nftypes.ICMPv6
default: default:
return nftypes.ProtocolUnknown return nftypes.ProtocolUnknown
} }
@@ -1286,7 +1093,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, false if the packet is valid and not a fragment. // It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid. // It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.decodePacket(packetData); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err) m.logger.Trace1("couldn't decode packet, err: %s", err)
return false, false return false, false
} }
@@ -1299,21 +1106,10 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
} }
// Fragments are also valid // Fragments are also valid
if l == 1 { if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
switch d.decoded[0] { ip4 := d.ip4
case layers.LayerTypeIPv4: if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 { return true, true
return true, true
}
case layers.LayerTypeIPv6:
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
// only decoded the IPv6 layer, the transport is in a fragment.
// TODO: handle non-Fragment extension headers (HopByHop, Routing,
// DestOpts) by walking the chain. gopacket's parser does not
// support them as DecodingLayers; today we drop such packets.
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
return true, true
}
} }
} }
@@ -1351,35 +1147,21 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
size, size,
) )
case layers.LayerTypeICMPv6: // TODO: ICMPv6
id, _ := icmpv6EchoFields(d)
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
id,
d.icmp6.TypeCode.Type(),
size,
)
} }
return false return false
} }
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed. // isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
func (m *Manager) isSpecialICMP(d *decoder) bool { func (m *Manager) isSpecialICMP(d *decoder) bool {
switch d.decoded[1] { if d.decoded[1] != layers.LayerTypeICMPv4 {
case layers.LayerTypeICMPv4: return false
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
case layers.LayerTypeICMPv6:
icmpType := d.icmp6.TypeCode.Type()
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
icmpType == layers.ICMPv6TypePacketTooBig ||
icmpType == layers.ICMPv6TypeTimeExceeded ||
icmpType == layers.ICMPv6TypeParameterProblem
} }
return false
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) { func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
@@ -1436,7 +1218,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true return rule.mgmtId, rule.drop, true
} }
if !protoLayerMatches(rule.protoLayer, payloadLayer) { if payloadLayer != rule.protoLayer {
continue continue
} }
@@ -1446,6 +1228,12 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true return rule.mgmtId, rule.drop, true
} }
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true return rule.mgmtId, rule.drop, true
} }
@@ -1471,7 +1259,8 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
} }
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) { // TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false return false
} }
@@ -1503,14 +1292,65 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched return sourceMatched
} }
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove. // AddUDPPacketHook calls hook when UDP packet from given direction matched
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { //
common.SetHook(&m.udpHookOut, ip, dPort, hook) // Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
}
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
} }
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove. // RemovePacketHook removes packet hook by given ID
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { func (m *Manager) RemovePacketHook(hookID string) error {
common.SetHook(&m.tcpHookOut, ip, dPort, hook) m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
} }
// SetLogLevel sets the log level for the firewall manager // SetLogLevel sets the log level for the firewall manager
@@ -1532,14 +1372,13 @@ func (m *Manager) EnableRouting() error {
return nil return nil
} }
rules, err := m.blockInvalidRouted(m.wgIface) rule, err := m.blockInvalidRouted(m.wgIface)
// Persist whatever was installed even on partial failure, so DisableRouting
// can clean it up later.
m.blockRules = rules
if err != nil { if err != nil {
return fmt.Errorf("block invalid routed: %w", err) return fmt.Errorf("block invalid routed: %w", err)
} }
m.blockRule = rule
return nil return nil
} }
@@ -1555,16 +1394,9 @@ func (m *Manager) DisableRouting() error {
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
m.nativeRouter.Store(false) m.nativeRouter.Store(false)
var merr *multierror.Error // don't stop forwarder if in use by netstack
for _, rule := range m.blockRules {
if err := m.deleteRouteRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete block rule: %w", err))
}
}
m.blockRules = nil
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
return nberrors.FormatErrorOrNil(merr) return nil
} }
fwder.Stop() fwder.Stop()
@@ -1572,7 +1404,14 @@ func (m *Manager) DisableRouting() error {
log.Debug("forwarder stopped") log.Debug("forwarder stopped")
return nberrors.FormatErrorOrNil(merr) if m.blockRule != nil {
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil
} }
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port // RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
@@ -1626,8 +1465,7 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
} }
// traffic to our other local interfaces (not NetBird IP) - always forward // traffic to our other local interfaces (not NetBird IP) - always forward
addr := m.wgIface.Address() if dstIP != m.wgIface.Address().IP {
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
return true return true
} }

View File

@@ -1023,8 +1023,7 @@ func BenchmarkMSSClamping(b *testing.B) {
}() }()
manager.mssClampEnabled = true manager.mssClampEnabled = true
manager.mssClampValueIPv4 = 1240 manager.mssClampValue = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8") dstIP := net.ParseIP("8.8.8.8")
@@ -1089,8 +1088,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
manager.mssClampEnabled = sc.enabled manager.mssClampEnabled = sc.enabled
if sc.enabled { if sc.enabled {
manager.mssClampValueIPv4 = 1240 manager.mssClampValue = 1240
manager.mssClampValueIPv6 = 1220
} }
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
@@ -1143,8 +1141,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
}() }()
manager.mssClampEnabled = true manager.mssClampEnabled = true
manager.mssClampValueIPv4 = 1240 manager.mssClampValue = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8") dstIP := net.ParseIP("8.8.8.8")

Some files were not shown because too many files have changed in this diff Show More