Mirror of https://github.com/netbirdio/netbird.git (synced 2026-05-09 10:19:55 +00:00)

Compare commits: debug-logs...vnc-server (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | b754df1171 |  |
@@ -1,130 +0,0 @@
body:
  - type: markdown
    attributes:
      value: |
        ## Ideas & Feature Requests

        Use this category for feature requests, enhancements, integrations, and product ideas.

        NetBird uses community traction in discussions — upvotes, replies, affected users, and use-case detail — as an input when deciding what should become a maintainer-curated issue or roadmap item. A clear problem statement is more useful than a solution-only request.

        Please search first and add your use case to an existing discussion when one already exists.

  - type: checkboxes
    id: preflight
    attributes:
      label: Before posting
      options:
        - label: I searched existing discussions and issues for similar requests.
          required: true
        - label: I checked the documentation to confirm this is not already supported.
          required: true
        - label: This is a product idea or enhancement request, not a support question.
          required: true
        - label: I removed or anonymized sensitive details from examples and screenshots.
          required: true

  - type: dropdown
    id: area
    attributes:
      label: Product area
      description: Select every area this request touches.
      multiple: true
      options:
        - Client / Agent
        - CLI
        - Desktop UI
        - Mobile app
        - Dashboard / Admin UI
        - Management service / API
        - Signal service
        - Relay
        - DNS
        - Routes / Exit nodes
        - NetBird SSH
        - Access control policies
        - Posture checks
        - Identity provider / SSO
        - Self-hosting / Deployment
        - Kubernetes / Operator
        - Terraform / Automation
        - Documentation
        - Other / not sure
    validations:
      required: true

  - type: textarea
    id: problem
    attributes:
      label: Problem or use case
      description: What are you trying to accomplish, and what is difficult or impossible today?
      placeholder: |
        As a ...
        I want to ...
        Because ...
    validations:
      required: true

  - type: textarea
    id: proposal
    attributes:
      label: Proposed solution
      description: Describe the behavior, workflow, API, UI, or integration you would like to see.
    validations:
      required: true

  - type: textarea
    id: alternatives
    attributes:
      label: Alternatives or workarounds considered
      description: What have you tried today? Why is the current workaround not enough?

  - type: textarea
    id: impact
    attributes:
      label: Community impact and priority
      description: Help us understand who benefits and how urgent this is.
      placeholder: |
        - Number of users/teams/peers affected:
        - Deployment type: Cloud / self-hosted / both
        - Frequency: daily / weekly / occasional
        - Blocking production adoption? yes/no
        - Related comments, discussions, or customer requests:
    validations:
      required: true

  - type: textarea
    id: examples
    attributes:
      label: Examples from other tools or products
      description: If another tool solves this well, link or describe the behavior.

  - type: textarea
    id: security
    attributes:
      label: Security, privacy, and compatibility considerations
      description: Note any access-control, audit, data retention, network, platform, or backward-compatibility concerns.

  - type: textarea
    id: implementation
    attributes:
      label: Implementation ideas
      description: Optional. If you are familiar with the codebase or API, share possible implementation notes.

  - type: dropdown
    id: contribution
    attributes:
      label: Are you willing to help?
      options:
        - Yes, I can submit a PR if the approach is accepted.
        - Yes, I can test or validate a proposed implementation.
        - Yes, I can provide more use-case details.
        - Not at this time.
    validations:
      required: true

  - type: textarea
    id: additional-context
    attributes:
      label: Additional context
      description: Add screenshots, diagrams, links, or anything else that helps explain the request.
.github/DISCUSSION_TEMPLATE/issue-triage.yml (237 lines, vendored)
@@ -1,237 +0,0 @@
body:
  - type: markdown
    attributes:
      value: |
        ## Issue Triage

        Use this category for reproducible bugs and regressions in NetBird.

        The more context you include, the faster we can validate and act on your report. If you're not sure whether something is a bug, **Q&A / Support** is a good starting point — we can always move the conversation here once we've confirmed it's a product issue.

        Intermittent issues are useful too. Include the trigger, frequency, timing, and any logs or debug evidence you have, and we'll work from there.

        Please don't include secrets, tokens, private keys, internal hostnames, or public IPs. Security vulnerabilities should be reported through the repository security policy rather than a public discussion.

  - type: checkboxes
    id: preflight
    attributes:
      label: Before posting
      options:
        - label: I searched existing discussions and issues, including closed ones, and checked the relevant docs.
          required: true
        - label: I believe this is a product bug rather than a configuration or setup question.
          required: true
        - label: I can reproduce this issue, or for intermittent issues I've included trigger, frequency, and timing details below.
          required: true
        - label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
          required: true

  - type: dropdown
    id: area
    attributes:
      label: Affected area
      description: Select every area this report touches.
      multiple: true
      options:
        - Client / Agent
        - Reverse Proxy
        - CLI
        - Desktop UI
        - Mobile app
        - Peer connectivity
        - DNS
        - Routes / Exit nodes
        - NetBird SSH
        - Relay / Signal / NAT traversal
        - Login / Authentication / IdP
        - Dashboard / Admin UI
        - Management service / API
        - Access control policies / Posture checks
        - Self-hosting / Deployment
        - Kubernetes / Operator
        - Documentation
        - Other / not sure
    validations:
      required: true

  - type: dropdown
    id: deployment
    attributes:
      label: Deployment type
      options:
        - NetBird Cloud
        - Self-hosted - quickstart script
        - Self-hosted - advanced/custom deployment
        - Local development build
        - Not sure / environment I do not fully control
    validations:
      required: true

  - type: dropdown
    id: platform
    attributes:
      label: Operating system or environment
      description: Select every environment involved in the reproduction.
      multiple: true
      options:
        - Linux
        - macOS
        - Windows
        - Android
        - iOS
        - FreeBSD
        - OpenWRT
        - Docker
        - Kubernetes
        - Synology
        - Browser
        - Other / not sure
    validations:
      required: true

  - type: textarea
    id: version
    attributes:
      label: NetBird version and upgrade status
      description: Run `netbird version` where applicable. For self-hosted deployments, include management, signal, relay, and dashboard versions if available. If you cannot test on a current/supported version, explain why.
      placeholder: |
        Example:
        - Client: 0.30.2
        - Management: 0.30.2
        - Signal: 0.30.2
        - Relay: 0.30.2
        - Dashboard: 0.30.2
        - Upgrade status: reproduced on current version / cannot upgrade because ...
    validations:
      required: true

  - type: dropdown
    id: regression
    attributes:
      label: Did this work before?
      options:
        - Yes, this worked before
        - No, this never worked
        - Not sure
    validations:
      required: true

  - type: textarea
    id: regression-details
    attributes:
      label: Regression details
      description: If this worked before, include the last known working version, first known broken version, and any recent upgrade, configuration, network, or IdP changes.
      placeholder: |
        - Last known working version:
        - First known broken version:
        - Recent changes:

  - type: textarea
    id: summary
    attributes:
      label: Summary
      description: Briefly describe the reproducible bug.
      placeholder: What is broken?
    validations:
      required: true

  - type: textarea
    id: current-behavior
    attributes:
      label: Current behavior
      description: What happens now? Include exact errors, timeouts, UI messages, or failed commands when possible.
    validations:
      required: true

  - type: textarea
    id: expected-behavior
    attributes:
      label: Expected behavior
      description: What did you expect to happen instead?
    validations:
      required: true

  - type: textarea
    id: reproduction
    attributes:
      label: Steps to reproduce
      description: Provide the smallest set of steps that reliably reproduces the bug. If the issue is intermittent, include the trigger, frequency, timing, and relevant timestamps.
      placeholder: |
        1. Configure ...
        2. Run ...
        3. Observe ...

        For intermittent issues:
        - Trigger:
        - Frequency:
        - Timing/timestamps:
    validations:
      required: true

  - type: textarea
    id: environment
    attributes:
      label: Environment and topology
      description: Include the relevant topology and software involved in the reproduction. For UI/docs-only reports, write `N/A` if this does not apply. Use `None`, `Unknown`, or `N/A` where appropriate.
      placeholder: |
        - Peer A:
        - Peer B:
        - Same LAN or different networks:
        - NAT/CGNAT/corporate firewall/mobile network:
        - Other VPN software:
        - Firewall, DNS, or endpoint security software:
        - Routes, DNS, policies, posture checks, or SSH rules involved:
        - IdP, reverse proxy, or browser involved:
    validations:
      required: true

  - type: textarea
    id: self-hosted-details
    attributes:
      label: Self-hosted details, if available
      description: Optional. If you use self-hosting and have access to these details, include them. If you do not administer the environment, provide what you know and say what you cannot access.
      placeholder: |
        - Deployment method: quickstart / Docker Compose / Helm / operator / custom
        - Management/signal/relay/dashboard versions:
        - Reverse proxy:
        - IdP/provider:
        - STUN/TURN/coturn/relay details:
        - Relevant component logs:

  - type: textarea
    id: logs
    attributes:
      label: Logs, status output, or debug evidence
      description: |
        For client, connectivity, DNS, route, relay/signal, or self-hosted reports, logs are essential — please include anonymized output from `netbird status -dA`, or a debug bundle via `netbird debug for 1m -AS -U`. Debug bundles are automatically deleted after 30 days.

        For UI, dashboard, or documentation reports, leave the pre-filled `N/A`.
      value: "N/A"
      render: shell
    validations:
      required: true

  - type: textarea
    id: related-reports
    attributes:
      label: Related issues or discussions
      description: Optional. Link similar reports you found while searching, if any.
      placeholder: |
        - Related issue/discussion:
        - Why this may be the same or different:

  - type: textarea
    id: impact
    attributes:
      label: Impact
      description: Optional. Help us understand priority. How many users, peers, environments, or workflows are affected? Is there a workaround?
      placeholder: |
        - Affected users/peers:
        - Business or production impact:
        - Workaround available:

  - type: textarea
    id: additional-context
    attributes:
      label: Additional context
      description: Add links to related discussions, issues, docs, screenshots, recordings, or anything else that may help validation.
.github/DISCUSSION_TEMPLATE/q-a-support.yml (146 lines, vendored)
@@ -1,146 +0,0 @@
body:
  - type: markdown
    attributes:
      value: |
        ## Q&A / Support

        Use this category for questions about configuration, setup, self-hosted deployments, troubleshooting, and general NetBird usage.

        This is community support and does not provide an SLA. For NetBird Cloud support, use the official support channel linked from the issue creation page. Please do not post secrets, tokens, private keys, internal hostnames, or public IPs unless you intentionally want them public.

        If your question turns into a reproducible product defect, DevRel or a maintainer may ask you to open or move the conversation to Issue Triage.

  - type: checkboxes
    id: preflight
    attributes:
      label: Before posting
      options:
        - label: I searched existing discussions and issues for similar questions.
          required: true
        - label: I reviewed the relevant NetBird documentation or troubleshooting guide.
          required: true
        - label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
          required: true

  - type: dropdown
    id: topic
    attributes:
      label: Topic
      multiple: true
      options:
        - Getting started
        - Self-hosting
        - Client / Agent
        - CLI
        - Desktop UI
        - Mobile app
        - Dashboard / Admin UI
        - DNS
        - Routes / Exit nodes
        - NetBird SSH
        - Relay
        - Access control policies
        - Posture checks
        - Identity provider / SSO
        - API
        - Kubernetes / Operator
        - Terraform / Automation
        - Documentation
        - Other / not sure
    validations:
      required: true

  - type: dropdown
    id: deployment
    attributes:
      label: Deployment type
      options:
        - NetBird Cloud
        - Self-hosted - quickstart script
        - Self-hosted - advanced/custom deployment
        - Local development build
        - Not sure
    validations:
      required: true

  - type: dropdown
    id: platform
    attributes:
      label: Operating system or environment
      multiple: true
      options:
        - Linux
        - macOS
        - Windows
        - Android
        - iOS
        - FreeBSD
        - OpenWRT
        - Docker
        - Kubernetes
        - Synology
        - Browser
        - Other / not sure
    validations:
      required: true

  - type: input
    id: version
    attributes:
      label: NetBird version
      description: Run `netbird version` where applicable. For self-hosted deployments, include component versions if relevant.
      placeholder: "Example: client 0.30.2, management 0.30.2"

  - type: textarea
    id: question
    attributes:
      label: Question
      description: What are you trying to understand or accomplish?
      placeholder: Describe your question clearly.
    validations:
      required: true

  - type: textarea
    id: goal
    attributes:
      label: Desired outcome
      description: What would a successful answer help you do?
      placeholder: |
        I want to configure ...
        I expected ...
        I need help deciding ...

  - type: textarea
    id: attempted
    attributes:
      label: What have you tried?
      description: Include commands, documentation links, configuration attempts, or troubleshooting steps already tried.
      placeholder: |
        - Read ...
        - Ran ...
        - Changed ...
        - Observed ...

  - type: textarea
    id: environment
    attributes:
      label: Relevant environment details
      description: Include redacted topology, IdP/provider, reverse proxy, firewall, DNS, route, policy, or self-hosted setup details that may affect the answer.
      placeholder: |
        - Deployment:
        - Components involved:
        - Network/topology:
        - Related config:

  - type: textarea
    id: logs
    attributes:
      label: Logs or output
      description: Optional. Include anonymized logs, command output, screenshots, or `netbird status -dA` if relevant.
      render: shell

  - type: textarea
    id: additional-context
    attributes:
      label: Additional context
      description: Add links, diagrams, screenshots, or other details that may help the community answer.
.github/ISSUE_TEMPLATE/bug-issue-report.md (71 lines, vendored, new file)
@@ -0,0 +1,71 @@
---
name: Bug/Issue report
about: Create a report to help us improve
title: ''
labels: ['triage-needed']
assignees: ''

---

**Describe the problem**

A clear and concise description of what the problem is.

**To Reproduce**

Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**

A clear and concise description of what you expected to happen.

**Are you using NetBird Cloud?**

Please specify whether you use NetBird Cloud or self-host NetBird's control plane.

**NetBird version**

`netbird version`

**Is any other VPN software installed?**

If yes, which one?

**Debug output**

To help us resolve the problem, please attach the following anonymized status output

    netbird status -dA

Create and upload a debug bundle, and share the returned file key:

    netbird debug for 1m -AS -U

*Uploaded files are automatically deleted after 30 days.*

Alternatively, create the file only and attach it here manually:

    netbird debug for 1m -AS

**Screenshots**

If applicable, add screenshots to help explain your problem.

**Additional context**

Add any other context about the problem here.

**Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings
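A minimal sketch of running the three commands above in one session (per the `forCmd` flag definitions later in this compare, `-S` adds system information and `-U` uploads the bundle; reading `-d` as detailed and `-A` as anonymized output is an assumption):

```bash
# Anonymized status overview to paste into the report
netbird status -dA

# One-minute debug bundle, anonymized and uploaded; share the returned file key
netbird debug for 1m -AS -U

# Same bundle written locally instead, for attaching the file manually
netbird debug for 1m -AS
```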
.github/ISSUE_TEMPLATE/config.yml (26 lines, vendored)
@@ -1,26 +1,14 @@
-blank_issues_enabled: false
+blank_issues_enabled: true
 contact_links:
-  - name: Start an Issue Triage discussion
-    url: https://github.com/netbirdio/netbird/discussions/new?category=issue-triage
-    about: Report a bug, regression, or unexpected behavior so DevRel can validate it before it becomes an issue.
-  - name: Propose an idea or feature request
-    url: https://github.com/netbirdio/netbird/discussions/new?category=ideas-feature-requests
-    about: Share feature requests, enhancements, and integration ideas for community feedback and prioritization.
-  - name: Ask a Q&A / Support question
-    url: https://github.com/netbirdio/netbird/discussions/new?category=q-a-support
-    about: Get help with setup, configuration, self-hosting, troubleshooting, and general usage.
-  - name: Security vulnerability disclosure
-    url: https://github.com/netbirdio/netbird/security/policy
-    about: Please do not report security vulnerabilities in public issues or discussions.
-  - name: Community Support Forum
+  - name: Community Support
     url: https://forum.netbird.io/
-    about: Community support forum.
+    about: Community support forum
   - name: Cloud Support
     url: https://docs.netbird.io/help/report-bug-issues
-    about: Contact NetBird for Cloud support.
-  - name: Client / Connection Troubleshooting
+    about: Contact us for support
+  - name: Client/Connection Troubleshooting
     url: https://docs.netbird.io/help/troubleshooting-client
-    about: See the client troubleshooting guide for common connectivity issues.
+    about: See our client troubleshooting guide for help addressing common issues
   - name: Self-host Troubleshooting
     url: https://docs.netbird.io/selfhosted/troubleshooting
-    about: See the self-host troubleshooting guide for common deployment issues.
+    about: See our self-host troubleshooting guide for help addressing common issues
.github/ISSUE_TEMPLATE/feature_request.md (20 lines, vendored, new file)
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ['feature-request']
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
.github/ISSUE_TEMPLATE/validated_issue.yml (128 lines, vendored)
@@ -1,128 +0,0 @@
name: Validated issue
description: Maintainer/DevRel only. Create an issue after a discussion has been validated or for internally validated work.
title: "[Validated]: "
body:
  - type: markdown
    attributes:
      value: |
        ## Discussion-first issue policy

        Issues are maintainer-curated work items. Community reports and feature requests should start in [Discussions](https://github.com/netbirdio/netbird/discussions) so DevRel can validate, reproduce, and route them before engineering time is committed.

        Use this form when:
        - A discussion has been validated and should become actionable work.
        - A maintainer is opening internally validated work that can bypass the discussion-first flow.

        Issues opened without a relevant validated discussion or maintainer context may be closed and redirected to Discussions.

  - type: checkboxes
    id: validation-checks
    attributes:
      label: Validation checklist
      options:
        - label: This issue is linked to a validated discussion, or it is being opened directly by a maintainer.
          required: true
        - label: The report has enough context for engineering to act on it without re-triaging from scratch.
          required: true
        - label: Sensitive data, secrets, private keys, internal hostnames, and public IPs have been removed or intentionally disclosed.
          required: true

  - type: dropdown
    id: issue-type
    attributes:
      label: Issue type
      options:
        - Bug / Regression
        - Feature / Enhancement
        - Documentation
        - Maintenance / Refactor
        - Cross-repository coordination
        - Other
    validations:
      required: true

  - type: input
    id: source-discussion
    attributes:
      label: Source discussion
      description: Link the GitHub Discussion that was validated. Maintainers bypassing the flow can write "Maintainer-created" and explain why below.
      placeholder: https://github.com/netbirdio/netbird/discussions/1234
    validations:
      required: true

  - type: input
    id: validation-owner
    attributes:
      label: Validation owner
      description: GitHub handle of the DevRel team member or maintainer who validated this work.
      placeholder: "@username"
    validations:
      required: true

  - type: dropdown
    id: target-repository
    attributes:
      label: Target repository
      description: Where should the implementation work happen?
      options:
        - netbirdio/netbird
        - netbirdio/dashboard
        - netbirdio/kubernetes-operator
        - netbirdio/docs
        - Multiple repositories
        - Unknown / needs routing
    validations:
      required: true

  - type: textarea
    id: summary
    attributes:
      label: Summary
      description: Concise description of the validated work.
      placeholder: What needs to be fixed, changed, documented, or built?
    validations:
      required: true

  - type: textarea
    id: evidence
    attributes:
      label: Validation evidence
      description: For bugs, include reproduction status, affected versions, logs, and environment. For features, include community traction, affected users, and alignment notes.
      placeholder: |
        - Reproduced by:
        - Affected versions / platforms:
        - Community signal:
        - Related logs or screenshots:
    validations:
      required: true

  - type: textarea
    id: scope
    attributes:
      label: Proposed scope
      description: Describe what is in scope and, if helpful, what is explicitly out of scope.
      placeholder: |
        In scope:
        - ...

        Out of scope:
        - ...
    validations:
      required: true

  - type: textarea
    id: acceptance-criteria
    attributes:
      label: Acceptance criteria
      description: What must be true for this issue to be closed?
      placeholder: |
        - [ ] ...
        - [ ] ...
    validations:
      required: true

  - type: textarea
    id: additional-context
    attributes:
      label: Additional context
      description: Links to related PRs, docs, issues in other repositories, roadmap items, or implementation notes.
.github/workflows/release.yml (307 lines, vendored)
@@ -9,7 +9,7 @@ on:
   pull_request:
 
 env:
-  SIGN_PIPE_VER: "v0.1.4"
+  SIGN_PIPE_VER: "v0.1.2"
   GORELEASER_VER: "v2.14.3"
   PRODUCT_NAME: "NetBird"
   COPYRIGHT: "NetBird GmbH"
@@ -114,13 +108,7 @@ jobs:
           retention-days: 30
 
   release:
-    runs-on: ubuntu-24.04-8-core
-    outputs:
-      release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
-      linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
-      windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
-      macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
-      ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
+    runs-on: ubuntu-latest-m
     env:
       flags: ""
     steps:
@@ -219,13 +213,10 @@
         if: always()
         run: rm -f /tmp/gpg-rpm-signing-key.asc
       - name: Tag and push images (amd64 only)
-        id: tag_and_push_images
        if: |
           (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
           (github.event_name == 'push' && github.ref == 'refs/heads/main')
         run: |
-          set -euo pipefail
-
          resolve_tags() {
            if [[ "${{ github.event_name }}" == "pull_request" ]]; then
              echo "pr-${{ github.event.pull_request.number }}"
@@ -234,17 +225,6 @@
             fi
           }
 
-          ghcr_package_url() {
-            local image="$1" package encoded_package
-            package="${image#ghcr.io/}"
-            package="${package#*/}"
-            package="${package%%:*}"
-            encoded_package="${package//\//%2F}"
-            echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
-          }
-
-          image_refs=()
-
           tag_and_push() {
             local src="$1" img_name tag dst
             img_name="${src%%:*}"
@@ -253,56 +233,35 @@
               echo "Tagging ${src} -> ${dst}"
               docker tag "$src" "$dst"
               docker push "$dst"
-              image_refs+=("$dst")
             done
           }
 
-          cat > /tmp/goreleaser-artifacts.json <<'JSON'
-          ${{ steps.goreleaser.outputs.artifacts }}
-          JSON
-          export -f tag_and_push resolve_tags
-
-          mapfile -t src_images < <(
-            jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
-          )
-
-          for src in "${src_images[@]}"; do
-            tag_and_push "$src"
-          done
-
-          {
-            echo "images_markdown<<EOF"
-            if [[ ${#image_refs[@]} -eq 0 ]]; then
-              echo "_No GHCR images were pushed._"
-            else
-              printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
-                printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
-              done
-            fi
-            echo "EOF"
-          } >> "$GITHUB_OUTPUT"
+          echo '${{ steps.goreleaser.outputs.artifacts }}' | \
+            jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
+            grep '^ghcr.io/' | while read -r SRC; do
+              tag_and_push "$SRC"
+            done
       - name: upload non tags for debug purposes
-        id: upload_release
         uses: actions/upload-artifact@v4
         with:
           name: release
           path: dist/
           retention-days: 7
       - name: upload linux packages
-        id: upload_linux_packages
         uses: actions/upload-artifact@v4
         with:
           name: linux-packages
           path: dist/netbird_linux**
           retention-days: 7
       - name: upload windows packages
-        id: upload_windows_packages
         uses: actions/upload-artifact@v4
         with:
           name: windows-packages
           path: dist/netbird_windows**
           retention-days: 7
       - name: upload macos packages
-        id: upload_macos_packages
         uses: actions/upload-artifact@v4
         with:
           name: macos-packages
@@ -311,8 +270,6 @@
 
   release_ui:
     runs-on: ubuntu-latest
-    outputs:
-      release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
     steps:
       - name: Parse semver string
         id: semver_parser
@@ -403,7 +360,6 @@
         if: always()
         run: rm -f /tmp/gpg-rpm-signing-key.asc
       - name: upload non tags for debug purposes
-        id: upload_release_ui
         uses: actions/upload-artifact@v4
         with:
           name: release-ui
@@ -412,8 +368,6 @@
 
   release_ui_darwin:
     runs-on: macos-latest
-    outputs:
-      release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
     steps:
       - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
         run: echo "flags=--snapshot" >> $GITHUB_ENV
@@ -448,258 +402,15 @@
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
       - name: upload non tags for debug purposes
-        id: upload_release_ui_darwin
         uses: actions/upload-artifact@v4
         with:
           name: release-ui-darwin
           path: dist/
           retention-days: 3
 
-  test_windows_installer:
-    name: "Windows Installer / Build Test"
-    runs-on: windows-2022
-    needs: [release, release_ui]
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - arch: amd64
-            wintun_arch: amd64
-          - arch: arm64
-            wintun_arch: arm64
-    defaults:
-      run:
-        shell: powershell
-    env:
-      PackageWorkdir: netbird_windows_${{ matrix.arch }}
-      downloadPath: '${{ github.workspace }}\temp'
-    steps:
-      - name: Parse semver string
-        id: semver_parser
-        uses: booxmedialtd/ws-action-parse-semver@v1
-        with:
-          input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
-          version_extractor_regex: '\/v(.*)$'
-
-      - name: Checkout
-        uses: actions/checkout@v4
-
-      - name: Add 7-Zip to PATH
-        run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-
-      - name: Download release artifacts
-        uses: actions/download-artifact@v4
-        with:
-          name: release
-          path: release
-
-      - name: Download UI release artifacts
-        uses: actions/download-artifact@v4
-        with:
-          name: release-ui
-          path: release-ui
-
-      - name: Stage binaries into dist
-        run: |
-          $workdir = "dist\${{ env.PackageWorkdir }}"
-          New-Item -ItemType Directory -Force -Path $workdir | Out-Null
-          $client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
-          $ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
-          if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 }
-          if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 }
-          Write-Host "Client: $($client.FullName)"
-          Write-Host "UI: $($ui.FullName)"
-          tar -zvxf $client.FullName -C $workdir
-          tar -zvxf $ui.FullName -C $workdir
-          Get-ChildItem $workdir
-
-      - name: Download wintun
-        uses: carlosperate/download-file-action@v2
-        id: download-wintun
-        with:
-          file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
-          file-name: wintun.zip
-          location: ${{ env.downloadPath }}
-          sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
-
-      - name: Decompress wintun files
-        run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
-
-      - name: Move wintun.dll into dist
-        run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
-
-      - name: Download Mesa3D (amd64 only)
-        uses: carlosperate/download-file-action@v2
-        id: download-mesa3d
-        if: matrix.arch == 'amd64'
-        with:
-          file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
-          file-name: mesa3d.7z
-          location: ${{ env.downloadPath }}
-          sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
-
-      - name: Extract Mesa3D driver (amd64 only)
-        if: matrix.arch == 'amd64'
-        run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
-
-      - name: Move opengl32.dll into dist (amd64 only)
-        if: matrix.arch == 'amd64'
-        run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
-
-      - name: Download EnVar plugin for NSIS
-        uses: carlosperate/download-file-action@v2
-        with:
-          file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
-          file-name: envar_plugin.zip
-          location: ${{ github.workspace }}
-
-      - name: Extract EnVar plugin
-        run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
-
-      - name: Download ShellExecAsUser plugin for NSIS (amd64 only)
-        uses: carlosperate/download-file-action@v2
-        if: matrix.arch == 'amd64'
-        with:
-          file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
-          file-name: ShellExecAsUser_amd64-Unicode.7z
-          location: ${{ github.workspace }}
-
-      - name: Extract ShellExecAsUser plugin (amd64 only)
-        if: matrix.arch == 'amd64'
-        run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
-
-      - name: Build NSIS installer
-        uses: joncloud/makensis-action@v3.3
-        with:
-          additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
-          script-file: client/installer.nsis
-          arguments: "/V4 /DARCH=${{ matrix.arch }}"
-        env:
-          APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
-
-      - name: Rename NSIS installer
-        run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
-
-      - name: Install WiX
-        run: |
-          dotnet tool install --global wix --version 6.0.2
-          wix extension add WixToolset.Util.wixext/6.0.2
-
-      - name: Build MSI installer
-        env:
-          NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}"
-        run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }}
-
-      - name: Upload installer artifacts
-        if: always()
-        uses: actions/upload-artifact@v4
-        with:
-          name: windows-installer-test-${{ matrix.arch }}
-          path: |
-            netbird_installer_test_windows_${{ matrix.arch }}.exe
-            netbird_installer_test_windows_${{ matrix.arch }}.msi
-          retention-days: 3
-
-  comment_release_artifacts:
-    name: Comment release artifacts
-    runs-on: ubuntu-latest
-    needs: [release, release_ui, release_ui_darwin]
-    if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
-    permissions:
-      contents: read
-      issues: write
-      pull-requests: write
-    steps:
-      - name: Create or update PR comment
-        uses: actions/github-script@v7
-        env:
-          RELEASE_RESULT: ${{ needs.release.result }}
-          RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
-          RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
-          RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
-          LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
-          WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
-          MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
-          RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
-          RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
-          GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
-        with:
-          github-token: ${{ secrets.GITHUB_TOKEN }}
-          script: |
-            const marker = '<!-- netbird-release-artifacts -->';
-            const { owner, repo } = context.repo;
-            const issue_number = context.payload.pull_request.number;
-            const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
-            const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
-
-            const artifactCell = (url, result) => {
-              if (url) return `[Download](${url})`;
-              return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
-            };
-
-            const artifacts = [
-              ['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
-              ['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
-              ['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
-              ['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
-              ['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
-              ['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
-            ];
-
-            const artifactRows = artifacts
-              .map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
-              .join('\n');
-
-            const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
-
-            const body = [
-              marker,
-              '## Release artifacts',
-              '',
-              `Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
-              '',
-              '| Artifact | Link |',
-              '| --- | --- |',
-              artifactRows,
-              '',
-              '### GHCR images (amd64)',
-              ghcrImages,
-              '',
-              '_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
-            ].join('\n');
-
-            const comments = await github.paginate(github.rest.issues.listComments, {
-              owner,
-              repo,
-              issue_number,
-              per_page: 100,
-            });
-
-            const previous = comments.find(comment =>
-              comment.user?.type === 'Bot' && comment.body?.includes(marker)
-            );
-
-            if (previous) {
-              await github.rest.issues.updateComment({
-                owner,
-                repo,
-                comment_id: previous.id,
-                body,
-              });
-              core.info(`Updated release artifacts comment ${previous.id}`);
-            } else {
-              const { data } = await github.rest.issues.createComment({
-                owner,
-                repo,
-                issue_number,
-                body,
-              });
-              core.info(`Created release artifacts comment ${data.id}`);
-            }
-
   trigger_signer:
     runs-on: ubuntu-latest
-    needs: [release, release_ui, release_ui_darwin, test_windows_installer]
+    needs: [release, release_ui, release_ui_darwin]
     if: startsWith(github.ref, 'refs/tags/')
     steps:
       - name: Trigger binaries sign pipelines
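The deleted `ghcr_package_url` helper maps a GHCR image reference to its package page by stripping the registry prefix, the org segment, and the tag, then URL-encoding any slashes left in the package path. A standalone sketch of that mapping (the `pr-1234` tag is a made-up input):

```bash
ghcr_package_url() {
  local image="$1" package encoded_package
  package="${image#ghcr.io/}"           # drop the registry prefix
  package="${package#*/}"               # drop the org segment
  package="${package%%:*}"              # drop the tag
  encoded_package="${package//\//%2F}"  # URL-encode remaining slashes
  echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
}

ghcr_package_url "ghcr.io/netbirdio/netbird:pr-1234"
# -> https://github.com/orgs/netbirdio/packages/container/package/netbird
```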
.github/workflows/sync-tag.yml (28 lines, vendored)
@@ -9,8 +9,6 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
   cancel-in-progress: true
 
-# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
-# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
 jobs:
   trigger_sync_tag:
     runs-on: ubuntu-latest
@@ -22,30 +20,4 @@
       ref: main
       repo: ${{ secrets.UPSTREAM_REPO }}
       token: ${{ secrets.NC_GITHUB_TOKEN }}
       inputs: '{ "tag": "${{ github.ref_name }}" }'
-
-  trigger_android_bump:
-    runs-on: ubuntu-latest
-    if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
-    steps:
-      - name: Trigger android-client submodule bump
-        uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
-        with:
-          workflow: bump-netbird.yml
-          ref: main
-          repo: netbirdio/android-client
-          token: ${{ secrets.NC_GITHUB_TOKEN }}
-          inputs: '{ "tag": "${{ github.ref_name }}" }'
-
-  trigger_ios_bump:
-    runs-on: ubuntu-latest
-    if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
-    steps:
-      - name: Trigger ios-client submodule bump
-        uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
-        with:
-          workflow: bump-netbird.yml
-          ref: main
-          repo: netbirdio/ios-client
-          token: ${{ secrets.NC_GITHUB_TOKEN }}
-          inputs: '{ "tag": "${{ github.ref_name }}" }'
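The deleted comment documents why the dispatch inputs use `github.ref_name`: for a tag push the full ref and the short name differ, and the receiving workflows expect the short form. A quick illustration with an example value:

```bash
GITHUB_REF="refs/tags/v0.30.0"     # what github.ref carries for a tag push
echo "${GITHUB_REF}"               # -> refs/tags/v0.30.0
echo "${GITHUB_REF#refs/tags/}"    # -> v0.30.0, the short form github.ref_name exposes
```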
@@ -58,11 +58,6 @@ linters:
     govet:
       enable:
         - nilness
-      disable:
-        # The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
-        # directives but cannot perform the rewrite due to generic type
-        # parameter inference limitations in the Go inliner.
-        - inline
       enable-all: false
     revive:
       rules:
@@ -17,7 +17,6 @@ ENV \
     NETBIRD_BIN="/usr/local/bin/netbird" \
     NB_LOG_FILE="console,/var/log/netbird/client.log" \
     NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
-    NB_ENABLE_CAPTURE="false" \
     NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
 
 ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

@@ -23,7 +23,6 @@ ENV \
     NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
     NB_LOG_FILE="console,/var/lib/netbird/client.log" \
     NB_DISABLE_DNS="true" \
-    NB_ENABLE_CAPTURE="false" \
     NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
 
 ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
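The remaining ENV defaults can be overridden when starting the container; a hedged sketch (the image name and the entrypoint script's handling of these variables are assumptions, not confirmed by this diff):

```bash
docker run --rm \
  -e NB_LOG_FILE="console" \
  -e NB_ENTRYPOINT_SERVICE_TIMEOUT="60" \
  netbirdio/netbird
```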
@@ -1,196 +0,0 @@
package cmd

import (
    "context"
    "fmt"
    "io"
    "os"
    "os/signal"
    "path/filepath"
    "strings"
    "syscall"

    "github.com/hashicorp/go-multierror"
    "github.com/spf13/cobra"
    "google.golang.org/grpc/status"
    "google.golang.org/protobuf/types/known/durationpb"

    nberrors "github.com/netbirdio/netbird/client/errors"
    "github.com/netbirdio/netbird/client/proto"
    "github.com/netbirdio/netbird/util/capture"
)

var captureCmd = &cobra.Command{
    Use:   "capture",
    Short: "Capture packets on the WireGuard interface",
    Long: `Captures decrypted packets flowing through the WireGuard interface.

Default output is human-readable text. Use --pcap or --output for pcap binary.
Requires --enable-capture to be set at service install or reconfigure time.

Examples:
  netbird debug capture
  netbird debug capture host 100.64.0.1 and port 443
  netbird debug capture tcp
  netbird debug capture icmp
  netbird debug capture src host 10.0.0.1 and dst port 80
  netbird debug capture -o capture.pcap
  netbird debug capture --pcap | tshark -r -
  netbird debug capture --pcap | tcpdump -r - -n`,
    Args: cobra.ArbitraryArgs,
    RunE: runCapture,
}

func init() {
    debugCmd.AddCommand(captureCmd)

    captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
    captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
    captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
    captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
    captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
    captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
}

func runCapture(cmd *cobra.Command, args []string) error {
    conn, err := getClient(cmd)
    if err != nil {
        return err
    }
    defer func() {
        if err := conn.Close(); err != nil {
            cmd.PrintErrf(errCloseConnection, err)
        }
    }()

    client := proto.NewDaemonServiceClient(conn)

    req, err := buildCaptureRequest(cmd, args)
    if err != nil {
        return err
    }

    ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
    defer cancel()

    stream, err := client.StartCapture(ctx, req)
    if err != nil {
        return handleCaptureError(err)
    }

    // First Recv is the empty acceptance message from the server. If the
    // device is unavailable (kernel WG, not connected, capture disabled),
    // the server returns an error instead.
    if _, err := stream.Recv(); err != nil {
        return handleCaptureError(err)
    }

    out, cleanup, err := captureOutput(cmd)
    if err != nil {
        return err
    }

    if req.TextOutput {
        cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
    } else {
        cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
    }

    streamErr := streamCapture(ctx, cmd, stream, out)
    cleanupErr := cleanup()
    if streamErr != nil {
        return streamErr
    }
    return cleanupErr
}

func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
    req := &proto.StartCaptureRequest{}

    if len(args) > 0 {
        expr := strings.Join(args, " ")
        if _, err := capture.ParseFilter(expr); err != nil {
            return nil, fmt.Errorf("invalid filter: %w", err)
        }
        req.FilterExpr = expr
    }

    if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
        req.SnapLen = snap
    }
    if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
        if d < 0 {
            return nil, fmt.Errorf("duration must not be negative")
        }
        req.Duration = durationpb.New(d)
    }
    req.Verbose, _ = cmd.Flags().GetBool("verbose")
    req.Ascii, _ = cmd.Flags().GetBool("ascii")

    outPath, _ := cmd.Flags().GetString("output")
    forcePcap, _ := cmd.Flags().GetBool("pcap")
    req.TextOutput = !forcePcap && outPath == ""

    return req, nil
}

func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
    for {
        pkt, err := stream.Recv()
        if err != nil {
            if ctx.Err() != nil {
                cmd.PrintErrf("\nCapture stopped.\n")
                return nil //nolint:nilerr // user interrupted
            }
            if err == io.EOF {
                cmd.PrintErrf("\nCapture finished.\n")
                return nil
            }
            return handleCaptureError(err)
        }
        if _, err := out.Write(pkt.GetData()); err != nil {
            return fmt.Errorf("write output: %w", err)
        }
    }
}

// captureOutput returns the writer for capture data and a cleanup function
// that finalizes the file. Errors from the cleanup must be propagated.
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
    outPath, _ := cmd.Flags().GetString("output")
    if outPath == "" {
        return os.Stdout, func() error { return nil }, nil
    }

    f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
    if err != nil {
        return nil, nil, fmt.Errorf("create output file: %w", err)
    }
    tmpPath := f.Name()
    return f, func() error {
        var merr *multierror.Error
        if err := f.Close(); err != nil {
            merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
        }
        fi, statErr := os.Stat(tmpPath)
        if statErr != nil || fi.Size() == 0 {
            if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
                merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
            }
            return nberrors.FormatErrorOrNil(merr)
        }
        if err := os.Rename(tmpPath, outPath); err != nil {
            merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
            return nberrors.FormatErrorOrNil(merr)
        }
        cmd.PrintErrf("Wrote %s\n", outPath)
        return nberrors.FormatErrorOrNil(merr)
    }, nil
}

func handleCaptureError(err error) error {
    if s, ok := status.FromError(err); ok {
        return fmt.Errorf("%s", s.Message())
    }
    return err
}
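For reference, the flags registered in `init()` above combine as in the following invocations; filter expressions follow the examples in the command's help text, and since this change deletes the command, these apply to the old branch only:

```bash
netbird debug capture tcp and port 443 -v        # verbose text: seq/ack, TTL, window
netbird debug capture icmp --ascii               # print payload as ASCII after each packet
netbird debug capture -d 30s -o capture.pcap     # 30-second capture written to a pcap file
netbird debug capture --pcap | tcpdump -r - -n   # stream pcap binary to tcpdump
```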
@@ -9,7 +9,6 @@ import (
     log "github.com/sirupsen/logrus"
     "github.com/spf13/cobra"
     "google.golang.org/grpc/status"
-    "google.golang.org/protobuf/types/known/durationpb"
 
     "github.com/netbirdio/netbird/client/internal"
     "github.com/netbirdio/netbird/client/internal/debug"
@@ -240,50 +239,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
         }()
     }
 
-    captureStarted := false
-    if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
-        captureTimeout := duration + 30*time.Second
-        const maxBundleCapture = 10 * time.Minute
-        if captureTimeout > maxBundleCapture {
-            captureTimeout = maxBundleCapture
-        }
-        _, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
-            Timeout: durationpb.New(captureTimeout),
-        })
-        if err != nil {
-            cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
-        } else {
-            captureStarted = true
-            cmd.Println("Packet capture started.")
-            // Safety: always stop on exit, even if the normal stop below runs too.
-            defer func() {
-                if captureStarted {
-                    stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-                    defer cancel()
-                    if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
-                        cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
-                    }
-                }
-            }()
-        }
-    }
-
     if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
         return waitErr
     }
     cmd.Println("\nDuration completed")
 
-    if captureStarted {
-        stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-        defer cancel()
-        if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
-            cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
-        } else {
-            captureStarted = false
-            cmd.Println("Packet capture stopped.")
-        }
-    }
-
     if cpuProfilingStarted {
         if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
             cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
@@ -456,5 +416,4 @@ func init() {
     forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
     forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
     forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
-    forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
 }
@@ -10,7 +10,6 @@ import (
 
     log "github.com/sirupsen/logrus"
     "github.com/spf13/cobra"
-    "golang.org/x/term"
     "google.golang.org/grpc/codes"
     gstatus "google.golang.org/grpc/status"
 
@@ -24,7 +23,6 @@ import (
 
 func init() {
     loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
-    loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
     loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
     loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
 }
@@ -258,7 +256,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
 }
 
 func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
-    openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
+    openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
 
     resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
     if err != nil {
@@ -326,7 +324,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
         return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
     }
 
-    openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
+    openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
 
     tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
     if err != nil {
@@ -336,7 +334,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
     return &tokenInfo, nil
 }
 
-func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
+func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
     var codeMsg string
     if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
         codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -350,12 +348,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
         verificationURIComplete + " " + codeMsg)
     }
 
-    if showQR {
-        if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
-            printQRCode(f, verificationURIComplete)
-        }
-    }
-
     cmd.Println("")
 
     if !noBrowser {
@@ -1,25 +0,0 @@
package cmd

import (
"io"

"github.com/mdp/qrterminal/v3"
)

// printQRCode prints a QR code for the given URL to the writer.
// Called only when the user explicitly requests QR output via --qr.
func printQRCode(w io.Writer, url string) {
if url == "" {
return
}
qrterminal.GenerateWithConfig(url, qrterminal.Config{
Level: qrterminal.M,
Writer: w,
HalfBlocks: true,
BlackChar: qrterminal.BLACK_BLACK,
WhiteChar: qrterminal.WHITE_WHITE,
BlackWhiteChar: qrterminal.BLACK_WHITE,
WhiteBlackChar: qrterminal.WHITE_BLACK,
QuietZone: qrterminal.QUIET_ZONE,
})
}
@@ -1,26 +0,0 @@
package cmd

import (
"bytes"
"testing"
)

func TestPrintQRCode_EmptyURL(t *testing.T) {
var buf bytes.Buffer

printQRCode(&buf, "")

if buf.Len() != 0 {
t.Error("expected no output for empty URL")
}
}

func TestPrintQRCode_WritesOutput(t *testing.T) {
var buf bytes.Buffer

printQRCode(&buf, "https://example.com/auth")

if buf.Len() == 0 {
t.Error("expected QR code output for non-empty URL")
}
}
@@ -75,7 +75,6 @@ var (
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool

rootCmd = &cobra.Command{
@@ -152,6 +151,7 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(vncCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

@@ -44,7 +44,6 @@ func init() {
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")

rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
}

serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}

@@ -59,10 +59,6 @@ func buildServiceArguments() []string {
args = append(args, "--disable-update-settings")
}

if captureEnabled {
args = append(args, "--enable-capture")
}

if networksDisabled {
args = append(args, "--disable-networks")
}

@@ -28,7 +28,6 @@ type serviceParams struct {
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
EnableCapture bool `json:"enable_capture,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -80,7 +79,6 @@ func currentServiceParams() *serviceParams {
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
EnableCapture: captureEnabled,
DisableNetworks: networksDisabled,
}

@@ -146,10 +144,6 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
updateSettingsDisabled = params.DisableUpdateSettings
}

if !serviceCmd.PersistentFlags().Changed("enable-capture") {
captureEnabled = params.EnableCapture
}

if !serviceCmd.PersistentFlags().Changed("disable-networks") {
networksDisabled = params.DisableNetworks
}

@@ -535,7 +535,6 @@ func fieldToGlobalVar(field string) string {
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"EnableCapture": "captureEnabled",
"DisableNetworks": "networksDisabled",
"ServiceEnvVars": "serviceEnvVars",
}

@@ -36,7 +36,10 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
jwtCacheTTLFlag = "jwt-cache-ttl"

// Alias for backward compatibility.
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)

var (
@@ -61,7 +64,7 @@ var (
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
jwtCacheTTL int
)

func init() {
@@ -71,7 +74,9 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")

sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)

@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}
@@ -160,7 +160,7 @@ func startClientDaemon(
s := grpc.NewServer()

server := client.New(ctx,
"", "", false, false, false, false)
"", "", false, false, false)
if err := server.Start(); err != nil {
t.Fatal(err)
}

@@ -39,9 +39,6 @@ const (
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"

showQRFlag = "qr"
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"

profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
)
@@ -51,7 +48,6 @@ var (
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
showQR bool
profileName string
configPath string

@@ -84,7 +80,6 @@ func init() {
)

upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")

@@ -361,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -376,9 +374,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
req.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
req.SshJWTCacheTTL = &jwtCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -463,6 +464,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}

if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -484,8 +488,12 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth
}

if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
if cmd.Flag(disableVNCAuthFlag).Changed {
ic.DisableVNCAuth = &disableVNCAuth
}

if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &jwtCacheTTL
}

if cmd.Flag(interfaceNameFlag).Changed {
@@ -587,6 +595,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}

if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
@@ -608,9 +619,13 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth
}

if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
loginRequest.DisableVNCAuth = &disableVNCAuth
}

if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
}

if cmd.Flag(disableAutoConnectFlag).Changed {
271 client/cmd/vnc.go Normal file

@@ -0,0 +1,271 @@
package cmd

import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
"time"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)

var (
vncUsername string
vncHost string
vncMode string
vncListen string
vncNoBrowser bool
vncNoCache bool
)

func init() {
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
}

var vncCmd = &cobra.Command{
Use: "vnc [flags] [user@]host",
Short: "Connect to a NetBird peer via VNC",
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
The target peer must have the VNC server enabled.

Two modes are available:
- attach: view the current physical display (remote support)
- session: start a virtual desktop as the specified user (passwordless login)

Use --listen to start a local proxy for external VNC viewers:
netbird vnc --listen :5900 peer-hostname
vncviewer localhost:5900

Examples:
netbird vnc peer-hostname
netbird vnc --mode session --user alice peer-hostname
netbird vnc --listen :5900 peer-hostname`,
Args: cobra.MinimumNArgs(1),
RunE: vncFn,
}

func vncFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())

logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}

if err := parseVNCHostArg(args[0]); err != nil {
return err
}

ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
vncCtx, cancel := context.WithCancel(ctx)

errCh := make(chan error, 1)
go func() {
if err := runVNC(vncCtx, cmd); err != nil {
errCh <- err
}
cancel()
}()

select {
case <-sig:
cancel()
<-vncCtx.Done()
return nil
case err := <-errCh:
return err
case <-vncCtx.Done():
}

return nil
}

func parseVNCHostArg(arg string) error {
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return fmt.Errorf("invalid user@host format")
}
if vncUsername == "" {
vncUsername = parts[0]
}
vncHost = parts[1]
if vncMode == "attach" {
vncMode = "session"
}
} else {
vncHost = arg
}

if vncMode == "session" && vncUsername == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
vncUsername = sudoUser
} else if currentUser, err := user.Current(); err == nil {
vncUsername = currentUser.Username
}
}

return nil
}
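
parseVNCHostArg mutates the package-level vncUsername, vncHost, and vncMode: a user@host argument supplies the username and silently upgrades the default attach mode to session. An illustrative test of that path (not part of this diff) could look like:

func TestParseVNCHostArg_UserAtHost(t *testing.T) {
vncUsername, vncHost, vncMode = "", "", "attach"
if err := parseVNCHostArg("alice@peer1"); err != nil {
t.Fatal(err)
}
// user@host implies session mode when --mode was left at its default.
if vncUsername != "alice" || vncHost != "peer1" || vncMode != "session" {
t.Fatalf("got %s@%s mode=%s", vncUsername, vncHost, vncMode)
}
}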

func runVNC(ctx context.Context, cmd *cobra.Command) error {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() { _ = grpcConn.Close() }()

daemonClient := proto.NewDaemonServiceClient(grpcConn)

if vncMode == "session" {
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
} else {
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
}

// Obtain JWT token. If the daemon has no SSO configured, proceed without one
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
var jwtToken string
hint := profilemanager.GetLoginHint()
var browserOpener func(string) error
if !vncNoBrowser {
browserOpener = util.OpenBrowser
}

token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
if err != nil {
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
} else {
jwtToken = token
log.Debug("JWT authentication successful")
}

// Connect to the VNC server on the standard port (5900). The peer's firewall
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
vncAddr := net.JoinHostPort(vncHost, "5900")
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
if err != nil {
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
}
defer vncConn.Close()

// Send session header with mode, username, and JWT.
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
return fmt.Errorf("send VNC header: %w", err)
}

cmd.Printf("VNC connected to %s\n", vncHost)

if vncListen != "" {
return runVNCLocalProxy(ctx, cmd, vncConn)
}

// No --listen flag: inform the user they need to use --listen for external viewers.
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
cmd.Printf("Press Ctrl+C to disconnect.\n")
<-ctx.Done()
return nil
}

const vncDialTimeout = 15 * time.Second

// sendVNCHeader writes the NetBird VNC session header: one mode byte (0 =
// attach, 1 = session), then a big-endian uint16 length-prefixed username,
// then a length-prefixed JWT.
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
var modeByte byte
if mode == "session" {
modeByte = 1
}

usernameBytes := []byte(username)
jwtBytes := []byte(jwt)
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
hdr[0] = modeByte
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
off += 2
copy(hdr[off:], jwtBytes)

_, err := conn.Write(hdr)
return err
}
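
The header carries no framing beyond the two length prefixes, so a receiver must consume exactly these fields in order before the RFB stream begins. A hypothetical server-side parser for the same layout (not part of this diff) would mirror the writer field for field:

// readVNCHeader is a sketch of the receiving side of sendVNCHeader's format:
// 1 mode byte, big-endian uint16 username length, username bytes,
// big-endian uint16 JWT length, JWT bytes.
func readVNCHeader(conn net.Conn) (mode byte, username, jwt string, err error) {
var fixed [3]byte
if _, err = io.ReadFull(conn, fixed[:]); err != nil {
return 0, "", "", fmt.Errorf("read header prefix: %w", err)
}
mode = fixed[0]
userBuf := make([]byte, binary.BigEndian.Uint16(fixed[1:3]))
if _, err = io.ReadFull(conn, userBuf); err != nil {
return 0, "", "", fmt.Errorf("read username: %w", err)
}
var jwtLen [2]byte
if _, err = io.ReadFull(conn, jwtLen[:]); err != nil {
return 0, "", "", fmt.Errorf("read jwt length: %w", err)
}
jwtBuf := make([]byte, binary.BigEndian.Uint16(jwtLen[:]))
if _, err = io.ReadFull(conn, jwtBuf); err != nil {
return 0, "", "", fmt.Errorf("read jwt: %w", err)
}
return mode, string(userBuf), string(jwtBuf), nil
}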

// runVNCLocalProxy listens on the given address and proxies incoming
// connections to the already-established VNC tunnel.
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
listener, err := net.Listen("tcp", vncListen)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncListen, err)
}
defer listener.Close()

cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
cmd.Printf("Press Ctrl+C to stop.\n")

go func() {
<-ctx.Done()
listener.Close()
}()

// Accept a single viewer connection. VNC is single-session: the RFB
// handshake completes on vncConn for the first viewer, so subsequent
// viewers would get a mid-stream connection. The loop handles transient
// accept errors until a valid connection arrives.
for {
clientConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
}
log.Debugf("accept VNC proxy client: %v", err)
continue
}

cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())

// Bidirectional copy.
done := make(chan struct{})
go func() {
io.Copy(vncConn, clientConn)
close(done)
}()
io.Copy(clientConn, vncConn)
<-done
clientConn.Close()

cmd.Printf("VNC viewer disconnected\n")
return nil
}
}

62 client/cmd/vnc_agent.go Normal file

@@ -0,0 +1,62 @@
//go:build windows

package cmd

import (
"net/netip"
"os"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"

vncserver "github.com/netbirdio/netbird/client/vnc/server"
)

var vncAgentPort string

func init() {
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
rootCmd.AddCommand(vncAgentCmd)
}

// vncAgentCmd runs a VNC server in the current user session, listening on
// localhost. It is spawned by the NetBird service (Session 0) via
// CreateProcessAsUser into the interactive console session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
// The agent's stderr is piped to the service, which re-logs it.
// Use JSON format with caller info for structured parsing.
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)

sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)

capturer := vncserver.NewDesktopCapturer()
injector := vncserver.NewWindowsInputInjector()
srv := vncserver.New(capturer, injector, "")
// Auth is handled by the service. The agent verifies a token on each
// connection to ensure only the service process can connect.
// The token is passed via environment variable to avoid exposing it
// in the process command line (visible via tasklist/wmic).
srv.SetDisableAuth(true)
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))

port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
if err != nil {
return err
}

loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
return err
}

<-cmd.Context().Done()
return srv.Stop()
},
}
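
The agent side above only reads NB_VNC_AGENT_TOKEN; the launching side must place the token in the child's environment rather than on argv. A hypothetical sketch of that launch (the real service uses CreateProcessAsUser to enter the console session, which plain exec.Command below does not do):

// spawnVNCAgent illustrates the env-based token handoff; assumes imports
// of context, fmt, os, and os/exec.
func spawnVNCAgent(ctx context.Context, token, port string) (*exec.Cmd, error) {
agent := exec.CommandContext(ctx, "netbird", "vnc-agent", "--port", port)
// Token travels via the environment so it never shows up in tasklist.
agent.Env = append(os.Environ(), "NB_VNC_AGENT_TOKEN="+token)
agent.Stderr = os.Stderr // the service re-logs the agent's JSON log lines
if err := agent.Start(); err != nil {
return nil, fmt.Errorf("start vnc agent: %w", err)
}
return agent, nil
}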

16 client/cmd/vnc_flags.go Normal file

@@ -0,0 +1,16 @@
package cmd

const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCAuthFlag = "disable-vnc-auth"
)

var (
serverVNCAllowed bool
disableVNCAuth bool
)

func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
}

229 client/cmd/vnc_recordings.go Normal file

@@ -0,0 +1,229 @@
package cmd

import (
"crypto/ecdh"
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"text/tabwriter"
"time"

"github.com/spf13/cobra"

vncserver "github.com/netbirdio/netbird/client/vnc/server"
"github.com/netbirdio/netbird/util"
)

var vncRecDir string

func init() {
vncRecPlayCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
vncRecListCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
vncRecCmd.AddCommand(vncRecListCmd)
vncRecCmd.AddCommand(vncRecPlayCmd)
vncRecCmd.AddCommand(vncRecKeygenCmd)
vncCmd.AddCommand(vncRecCmd)
}

var vncRecCmd = &cobra.Command{
Use: "rec",
Short: "Manage VNC session recordings",
}

var vncRecKeygenCmd = &cobra.Command{
Use: "keygen",
Short: "Generate an X25519 keypair for recording encryption",
Long: `Generates an X25519 keypair. Put the public key in management settings
(Session Recording > Encryption Key). Keep the private key safe for decrypting recordings.`,
RunE: vncRecKeygenFn,
}

var vncRecListCmd = &cobra.Command{
Use: "list",
Short: "List VNC session recordings",
RunE: vncRecListFn,
}

var vncRecPlayCmd = &cobra.Command{
Use: "play <file-or-name>",
Short: "Open a VNC recording in the browser",
Long: `Opens a browser-based player with playback controls:
play/pause, seek, speed (0.25x to 8x), keyboard shortcuts.

Examples:
netbird vnc rec play last
netbird vnc rec play 20260416-104433_vnc.rec`,
Args: cobra.ExactArgs(1),
RunE: vncRecPlayFn,
}

func vncRecListFn(cmd *cobra.Command, _ []string) error {
dir, err := resolveVNCRecDir()
if err != nil {
return err
}

entries, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("read recording dir %s: %w", dir, err)
}

w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "FILE\tSIZE\tDIMENSIONS\tUSER\tREMOTE\tMODE\tDATE")

for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
continue
}
filePath := filepath.Join(dir, entry.Name())
info, err := entry.Info()
if err != nil {
continue
}

header, err := vncserver.ReadRecordingHeader(filePath)
if err != nil {
fmt.Fprintf(w, "%s\t%s\t?\t?\t?\t?\t?\n", entry.Name(), vncFormatSize(info.Size()))
continue
}

fmt.Fprintf(w, "%s\t%s\t%dx%d\t%s\t%s\t%s\t%s\n",
entry.Name(),
vncFormatSize(info.Size()),
header.Width, header.Height,
header.Meta.User,
header.Meta.RemoteAddr,
header.Meta.Mode,
header.StartTime.Format("2006-01-02 15:04:05"),
)
}

return w.Flush()
}

func vncRecPlayFn(cmd *cobra.Command, args []string) error {
filePath, err := resolveVNCRecFile(args[0])
if err != nil {
return err
}

header, err := vncserver.ReadRecordingHeader(filePath)
if err != nil {
return fmt.Errorf("read recording: %w", err)
}

cmd.Printf("Recording: %s (%dx%d)\n", filepath.Base(filePath), header.Width, header.Height)

url, err := vncserver.ServeWebPlayer(filePath, "localhost:0")
if err != nil {
return fmt.Errorf("start web player: %w", err)
}
cmd.Printf("Player: %s\n", url)
if err := util.OpenBrowser(url); err != nil {
cmd.Printf("Open %s in your browser\n", url)
}
cmd.Printf("Press Ctrl+C to stop.\n")
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
<-sig
return nil
}

func vncRecKeygenFn(cmd *cobra.Command, _ []string) error {
priv, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("generate key: %w", err)
}

privB64 := base64.StdEncoding.EncodeToString(priv.Bytes())
pubB64 := base64.StdEncoding.EncodeToString(priv.PublicKey().Bytes())

cmd.Printf("Private key (keep secret, for decrypting recordings):\n %s\n\n", privB64)
cmd.Printf("Public key (paste into management Settings > Session Recording > Encryption Key):\n %s\n", pubB64)
return nil
}
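
How the two keys are combined is not shown in this diff; assuming a conventional X25519 ECDH construction, deriving a shared secret from keygen's base64 output could look like the sketch below (illustrative only, not NetBird's confirmed scheme):

// deriveRecordingKey combines one side's private key with the other side's
// public key; both sides arrive at the same shared secret, which would
// typically be run through a KDF before use as a cipher key.
func deriveRecordingKey(privB64, peerPubB64 string) ([]byte, error) {
privBytes, err := base64.StdEncoding.DecodeString(privB64)
if err != nil {
return nil, fmt.Errorf("decode private key: %w", err)
}
pubBytes, err := base64.StdEncoding.DecodeString(peerPubB64)
if err != nil {
return nil, fmt.Errorf("decode public key: %w", err)
}
priv, err := ecdh.X25519().NewPrivateKey(privBytes)
if err != nil {
return nil, err
}
pub, err := ecdh.X25519().NewPublicKey(pubBytes)
if err != nil {
return nil, err
}
return priv.ECDH(pub)
}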

func vncFormatSize(size int64) string {
switch {
case size >= 1<<20:
return fmt.Sprintf("%.1fM", float64(size)/float64(1<<20))
case size >= 1<<10:
return fmt.Sprintf("%.1fK", float64(size)/float64(1<<10))
default:
return fmt.Sprintf("%dB", size)
}
}

func resolveVNCRecDir() (string, error) {
if vncRecDir != "" {
return vncRecDir, nil
}
candidates := []string{
"/var/lib/netbird/recordings/vnc",
filepath.Join(os.Getenv("HOME"), ".netbird/recordings/vnc"),
}
for _, dir := range candidates {
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
return dir, nil
}
}
return "", fmt.Errorf("no VNC recording directory found; use --dir to specify")
}

func resolveVNCRecFile(arg string) (string, error) {
if strings.Contains(arg, "/") || strings.Contains(arg, string(os.PathSeparator)) {
return arg, nil
}

dir, err := resolveVNCRecDir()
if err != nil && arg != "last" {
return arg, nil
}

if arg == "last" {
if err != nil {
return "", err
}
return findLatestRec(dir)
}

full := filepath.Join(dir, arg)
if _, err := os.Stat(full); err == nil {
return full, nil
}
return arg, nil
}

func findLatestRec(dir string) (string, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return "", fmt.Errorf("read dir: %w", err)
}

var latest string
var latestTime time.Time
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
if info.ModTime().After(latestTime) {
latestTime = info.ModTime()
latest = filepath.Join(dir, entry.Name())
}
}
if latest == "" {
return "", fmt.Errorf("no recordings found in %s", dir)
}
return latest, nil
}
@@ -1,65 +0,0 @@
package embed

import (
"io"

"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/capture"
)

// CaptureOptions configures a packet capture session.
type CaptureOptions struct {
// Output receives pcap-formatted data. Nil disables pcap output.
Output io.Writer
// TextOutput receives human-readable packet summaries. Nil disables text output.
TextOutput io.Writer
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
// Empty captures all packets.
Filter string
// Verbose adds seq/ack, TTL, window, and total length to text output.
Verbose bool
// ASCII dumps transport payload as printable ASCII after each packet line.
ASCII bool
}

// CaptureStats reports capture session counters.
type CaptureStats struct {
Packets int64
Bytes int64
Dropped int64
}

// CaptureSession represents an active packet capture. Call Stop to end the
// capture and flush buffered packets.
type CaptureSession struct {
sess *capture.Session
engine *internal.Engine
}

// Stop ends the capture, flushes remaining packets, and detaches from the device.
// Safe to call multiple times.
func (cs *CaptureSession) Stop() {
if cs.engine != nil {
_ = cs.engine.SetCapture(nil)
cs.engine = nil
}
if cs.sess != nil {
cs.sess.Stop()
}
}

// Stats returns current capture counters.
func (cs *CaptureSession) Stats() CaptureStats {
s := cs.sess.Stats()
return CaptureStats{
Packets: s.Packets,
Bytes: s.Bytes,
Dropped: s.Dropped,
}
}

// Done returns a channel that is closed when the capture's writer goroutine
// has fully exited and all buffered packets have been flushed.
func (cs *CaptureSession) Done() <-chan struct{} {
return cs.sess.Done()
}
@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/capture"
)

var (
@@ -66,7 +65,7 @@ type Options struct {
PrivateKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the tunnel interface
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
@@ -82,9 +81,9 @@ type Options struct {
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
// MTU is the MTU for the WireGuard interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
@@ -470,52 +469,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}

// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}

var matcher capture.Matcher
if opts.Filter != "" {
m, err := capture.ParseFilter(opts.Filter)
if err != nil {
return nil, fmt.Errorf("parse filter: %w", err)
}
matcher = m
}

sess, err := capture.NewSession(capture.Options{
Output: opts.Output,
TextOutput: opts.TextOutput,
Matcher: matcher,
Verbose: opts.Verbose,
ASCII: opts.ASCII,
})
if err != nil {
return nil, fmt.Errorf("create capture session: %w", err)
}

if err := engine.SetCapture(sess); err != nil {
sess.Stop()
return nil, fmt.Errorf("set capture: %w", err)
}

return &CaptureSession{sess: sess, engine: engine}, nil
}

// StopCapture stops the active capture session if one is running.
func (c *Client) StopCapture() error {
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetCapture(nil)
}
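
Before its removal in this branch, the embed capture API could be driven as sketched below (illustrative; nbClient stands for an already-started embed client, and the imports of os and time are assumed):

// captureExample writes 30 seconds of tunnel traffic to a pcap file.
func captureExample(nbClient *Client) error {
f, err := os.Create("tunnel.pcap")
if err != nil {
return err
}
defer f.Close()

sess, err := nbClient.StartCapture(CaptureOptions{
Output: f,               // pcap stream
Filter: "tcp port 443",  // BPF-like filter; empty captures everything
})
if err != nil {
return err
}
time.Sleep(30 * time.Second)
sess.Stop()
<-sess.Done() // wait until buffered packets are flushed to f
return nil
}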
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||
// versions, which returns EPERM to any other process that tries to insert
|
||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||
// itself add the accept rules to its own chains by trusting the interface.
|
||||
package firewalld
|
||||
|
||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||
// should bypass firewalld filtering.
|
||||
const TrustedZone = "trusted"
|
||||
@@ -1,260 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
dbusDest = "org.fedoraproject.FirewallD1"
|
||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||
|
||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||
errNotEnabled = "NOT_ENABLED"
|
||||
|
||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||
// A fresh context is created for each call so a slow DBus probe can't
|
||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||
callTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||
|
||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||
// Info level only for the first successful add per process; repeat adds
|
||||
// from other init paths are quieter.
|
||||
trustLogOnce sync.Once
|
||||
|
||||
parentCtxMu sync.RWMutex
|
||||
parentCtx context.Context = context.Background()
|
||||
)
|
||||
|
||||
// SetParentContext installs a parent context whose cancellation aborts any
|
||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||
// during engine shutdown when the engine context is already cancelled.
|
||||
func SetParentContext(ctx context.Context) {
|
||||
parentCtxMu.Lock()
|
||||
parentCtx = ctx
|
||||
parentCtxMu.Unlock()
|
||||
}
|
||||
|
||||
func getParentContext() context.Context {
|
||||
parentCtxMu.RLock()
|
||||
defer parentCtxMu.RUnlock()
|
||||
return parentCtx
|
||||
}
|
||||
|
||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||
// running. It is idempotent and best-effort: errors are returned so callers
|
||||
// can log, but a non-running firewalld is not an error. Only the first
|
||||
// successful call per process logs at Info. Respects the parent context set
|
||||
// via SetParentContext so startup-time cancellation unblocks it.
|
||||
func TrustInterface(iface string) error {
|
||||
parent := getParentContext()
|
||||
if !isRunning(parent) {
|
||||
return nil
|
||||
}
|
||||
if err := addTrusted(parent, iface); err != nil {
|
||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
trustLogOnce.Do(func() {
|
||||
log.Infof("added %s to firewalld trusted zone", iface)
|
||||
})
|
||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||
// during shutdown after the engine context has been cancelled.
|
||||
func UntrustInterface(iface string) error {
|
||||
if !isRunning(context.Background()) {
|
||||
return nil
|
||||
}
|
||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(parent, callTimeout)
|
||||
}
|
||||
|
||||
func isRunning(parent context.Context) bool {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
ok, err := isRunningDBus(ctx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return ok
|
||||
}
|
||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return isRunningCLI(ctx)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := addDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return addCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func removeTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := removeDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return removeCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
var zone string
|
||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isRunningCLI(ctx context.Context) bool {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return false
|
||||
}
|
||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||
}
|
||||
|
||||
func addDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||
if move.Err != nil {
|
||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func removeDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func addCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
// --change-interface (no --permanent) binds the interface for the
|
||||
// current runtime only; we do not want membership to persist across
|
||||
// reboots because netbird re-asserts it on every startup.
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbusErrContains(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var de dbus.Error
|
||||
if errors.As(err, &de) {
|
||||
for _, b := range de.Body {
|
||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), code)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
func TestDBusErrContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
code string
|
||||
want bool
|
||||
}{
|
||||
{"nil error", nil, errZoneAlreadySet, false},
|
||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||
{
|
||||
"dbus.Error body match",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||
errZoneAlreadySet,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"dbus.Error body miss",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||
errAlreadyEnabled,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"dbus.Error non-string body falls back to Error()",
|
||||
dbus.Error{Name: "x", Body: []any{123}},
|
||||
"x",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := dbusErrContains(tc.err, tc.code)
|
||||
if got != tc.want {
|
||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import "context"
|
||||
|
||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func SetParentContext(context.Context) {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
}
|
||||
|
||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func TrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func UntrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -87,12 +86,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||
// interface in firewalld's trusted zone without a corresponding Close.
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -198,12 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
@@ -230,11 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -218,10 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
@@ -41,8 +40,6 @@ const (
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
@@ -136,10 +133,6 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
@@ -287,10 +280,6 @@ func (r *router) createContainers() error {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
@@ -1330,13 +1319,6 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
|
||||
@@ -3,9 +3,6 @@
package uspfilter

import (
	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/firewall/firewalld"
	"github.com/netbirdio/netbird/client/internal/statemanager"
)

@@ -19,9 +16,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
	if m.nativeFirewall != nil {
		return m.nativeFirewall.Close(stateManager)
	}
	if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
		log.Warnf("failed to untrust interface in firewalld: %v", err)
	}
	return nil
}

@@ -30,8 +24,5 @@ func (m *Manager) AllowNetbird() error {
	if m.nativeFirewall != nil {
		return m.nativeFirewall.AllowNetbird()
	}
	if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
		log.Warnf("failed to trust interface in firewalld: %v", err)
	}
	return nil
}
@@ -9,7 +9,6 @@ import (

// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
	Name() string
	SetFilter(device.PacketFilter) error
	Address() wgaddr.Address
	GetWGDevice() *wgdevice.Device
@@ -115,13 +115,12 @@ type Manager struct {

	localipmanager *localIPManager

	udpTracker     *conntrack.UDPTracker
	icmpTracker    *conntrack.ICMPTracker
	tcpTracker     *conntrack.TCPTracker
	forwarder      atomic.Pointer[forwarder.Forwarder]
	pendingCapture atomic.Pointer[forwarder.PacketCapture]
	logger         *nblog.Logger
	flowLogger     nftypes.FlowLogger
	udpTracker  *conntrack.UDPTracker
	icmpTracker *conntrack.ICMPTracker
	tcpTracker  *conntrack.TCPTracker
	forwarder   atomic.Pointer[forwarder.Forwarder]
	logger      *nblog.Logger
	flowLogger  nftypes.FlowLogger

	blockRule firewall.Rule
@@ -352,19 +351,6 @@ func (m *Manager) determineRouting() error {
	return nil
}

// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
	if pc == nil {
		m.pendingCapture.Store(nil)
	} else {
		m.pendingCapture.Store(&pc)
	}
	if fwder := m.forwarder.Load(); fwder != nil {
		fwder.SetCapture(pc)
	}
}

// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error {
	if m.forwarder.Load() != nil {
@@ -386,11 +372,6 @@ func (m *Manager) initForwarder() error {

	m.forwarder.Store(forwarder)

	// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
	if pc := m.pendingCapture.Load(); pc != nil {
		forwarder.SetCapture(*pc)
	}

	log.Debug("forwarder initialized")

	return nil
@@ -633,7 +614,6 @@ func (m *Manager) resetState() {
	}

	if fwder := m.forwarder.Load(); fwder != nil {
		fwder.SetCapture(nil)
		fwder.Stop()
	}
@@ -31,20 +31,12 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()

type IFaceMock struct {
	NameFunc        func() string
	SetFilterFunc   func(device.PacketFilter) error
	AddressFunc     func() wgaddr.Address
	GetWGDeviceFunc func() *wgdevice.Device
	GetDeviceFunc   func() *device.FilteredDevice
}

func (i *IFaceMock) Name() string {
	if i.NameFunc == nil {
		return "wgtest"
	}
	return i.NameFunc()
}

func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
	if i.GetWGDeviceFunc == nil {
		return nil
@@ -12,19 +12,12 @@ import (
	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)

// PacketCapture captures raw packets for debugging. Implementations must be
// safe for concurrent use and must not block.
type PacketCapture interface {
	Offer(data []byte, outbound bool)
}

// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
	logger     *nblog.Logger
	dispatcher stack.NetworkDispatcher
	device     *wgdevice.Device
	mtu        atomic.Uint32
	capture    atomic.Pointer[PacketCapture]
}

func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -61,17 +54,13 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
			continue
		}

		pktBytes := data.AsSlice()

		// Send the packet through WireGuard
		address := netHeader.DestinationAddress()
		if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
		err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
		if err != nil {
			e.logger.Error1("CreateOutboundPacket: %v", err)
			continue
		}

		if pc := e.capture.Load(); pc != nil {
			(*pc).Offer(pktBytes, true)
		}
		written++
	}
@@ -139,16 +139,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
	return f, nil
}

// SetCapture sets or clears the packet capture on the forwarder endpoint.
// This captures outbound packets that bypass the FilteredDevice (netstack forwarding).
func (f *Forwarder) SetCapture(pc PacketCapture) {
	if pc == nil {
		f.endpoint.capture.Store(nil)
		return
	}
	f.endpoint.capture.Store(&pc)
}

func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
	if len(payload) < header.IPv4MinimumSize {
		return fmt.Errorf("packet too small: %d bytes", len(payload))

@@ -270,9 +270,5 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []
		return 0
	}

	if pc := f.endpoint.capture.Load(); pc != nil {
		(*pc).Offer(fullPacket, true)
	}

	return len(fullPacket)
}
@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
		ipv6Count++
	}

	// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
	// routing-correctness checks above are the real assertions; the counts
	// are a sanity bound to catch a totally silent path.
	minDelivered := packetsPerFamily * 80 / 100
	assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
	assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
	assert.Equal(t, packetsPerFamily, ipv4Count)
	assert.Equal(t, packetsPerFamily, ipv6Count)
}

func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
@@ -3,7 +3,6 @@ package device
import (
	"net/netip"
	"sync"
	"sync/atomic"

	"golang.zx2c4.com/wireguard/tun"
)
@@ -29,20 +28,11 @@ type PacketFilter interface {
	SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
}

// PacketCapture captures raw packets for debugging. Implementations must be
// safe for concurrent use and must not block.
type PacketCapture interface {
	// Offer submits a packet for capture. outbound is true for packets
	// leaving the host (Read path), false for packets arriving (Write path).
	Offer(data []byte, outbound bool)
}

// FilteredDevice to override Read or Write of packets
type FilteredDevice struct {
	tun.Device

	filter    PacketFilter
	capture   atomic.Pointer[PacketCapture]
	mutex     sync.RWMutex
	closeOnce sync.Once
}
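No PacketCapture implementation appears anywhere in this compare. A minimal sketch of one that satisfies the documented contract (concurrent-safe, never blocks) — the type name and channel depth below are illustrative, not from the diff:

package capture

// ringCapture buffers packets in a bounded channel; Offer copies the data and
// drops under backpressure so the device hot path is never stalled.
type ringCapture struct {
	ch chan []byte
}

func newRingCapture(depth int) *ringCapture {
	return &ringCapture{ch: make(chan []byte, depth)}
}

// Offer implements the PacketCapture interface shown above.
func (r *ringCapture) Offer(data []byte, outbound bool) {
	pkt := make([]byte, len(data)) // copy: the caller reuses its buffer
	copy(pkt, data)
	select {
	case r.ch <- pkt:
	default: // full: drop the packet rather than block Read/Write
	}
}

Wired up, such a sink would be installed via FilteredDevice.SetCapture(newRingCapture(1024)), matching the atomic-pointer design below that keeps the capture-off hot path to a single pointer load.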
@@ -73,25 +63,20 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
	if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
		return 0, err
	}

	d.mutex.RLock()
	filter := d.filter
	d.mutex.RUnlock()

	if filter != nil {
		for i := 0; i < n; i++ {
			if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
				bufs = append(bufs[:i], bufs[i+1:]...)
				sizes = append(sizes[:i], sizes[i+1:]...)
				n--
				i--
			}
		}
	if filter == nil {
		return
	}

	if pc := d.capture.Load(); pc != nil {
		for i := 0; i < n; i++ {
			(*pc).Offer(bufs[i][offset:offset+sizes[i]], true)
	for i := 0; i < n; i++ {
		if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
			bufs = append(bufs[:i], bufs[i+1:]...)
			sizes = append(sizes[:i], sizes[i+1:]...)
			n--
			i--
		}
	}
@@ -100,13 +85,6 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er

// Write wraps write method with filtering feature
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
	// Capture before filtering so dropped packets are still visible in captures.
	if pc := d.capture.Load(); pc != nil {
		for _, buf := range bufs {
			(*pc).Offer(buf[offset:], false)
		}
	}

	d.mutex.RLock()
	filter := d.filter
	d.mutex.RUnlock()
@@ -118,10 +96,9 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
	filteredBufs := make([][]byte, 0, len(bufs))
	dropped := 0
	for _, buf := range bufs {
		if filter.FilterInbound(buf[offset:], len(buf)) {
			dropped++
		} else {
		if !filter.FilterInbound(buf[offset:], len(buf)) {
			filteredBufs = append(filteredBufs, buf)
			dropped++
		}
	}
@@ -136,14 +113,3 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
	d.filter = filter
	d.mutex.Unlock()
}

// SetCapture sets or clears the packet capture sink. Pass nil to disable.
// Uses atomic store so the hot path (Read/Write) is a single pointer load
// with no locking overhead when capture is off.
func (d *FilteredDevice) SetCapture(pc PacketCapture) {
	if pc == nil {
		d.capture.Store(nil)
		return
	}
	d.capture.Store(&pc)
}
@@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) {
		t.Errorf("unexpected error: %v", err)
		return
	}
	if n != 1 {
	if n != 0 {
		t.Errorf("expected n=1, got %d", n)
		return
	}
@@ -201,18 +201,7 @@ Pop $0

Function .onInit
StrCpy $INSTDIR "${INSTALL_DIR}"
; Default autostart to enabled so silent installs (/S) match the interactive default
StrCpy $AutostartEnabled "1"

; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
; in the 32-bit view. Fall back to it so upgrades still find them.
SetRegView 64
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
${If} $R0 == ""
SetRegView 32
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
SetRegView 64
${EndIf}
${If} $R0 != ""
# if silent install jump to uninstall step
IfSilent uninstall
@@ -225,10 +214,6 @@ ${If} $R0 != ""

${EndIf}
FunctionEnd

Function un.onInit
SetRegView 64
FunctionEnd
######################################################################
Section -MainProgram
${INSTALL_TYPE}
@@ -243,7 +228,6 @@ Section -MainProgram
!else
File /r "..\\dist\\netbird_windows_amd64\\"
!endif
File "..\\client\\ui\\assets\\netbird.png"
SectionEnd
######################################################################

@@ -263,11 +247,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -301,8 +283,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`

; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"

; Handle data deletion based on checkbox
@@ -341,7 +321,6 @@ DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"

DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM
@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
		a.config.RosenpassEnabled,
		a.config.RosenpassPermissive,
		a.config.ServerSSHAllowed,
		a.config.ServerVNCAllowed,
		a.config.DisableClientRoutes,
		a.config.DisableServerRoutes,
		a.config.DisableDNS,
@@ -327,6 +328,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
		a.config.EnableSSHLocalPortForwarding,
		a.config.EnableSSHRemotePortForwarding,
		a.config.DisableSSHAuth,
		a.config.DisableVNCAuth,
	)
}

@@ -333,10 +333,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
	c.statusRecorder.MarkSignalConnected()

	relayURLs, token := parseRelayInfo(loginResp)
	if override, ok := peer.OverrideRelayURLs(); ok {
		log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
		relayURLs = override
	}
	peerConfig := loginResp.GetPeerConfig()

	engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
@@ -550,11 +546,13 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
		RosenpassEnabled:              config.RosenpassEnabled,
		RosenpassPermissive:           config.RosenpassPermissive,
		ServerSSHAllowed:              util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
		ServerVNCAllowed:              config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
		EnableSSHRoot:                 config.EnableSSHRoot,
		EnableSSHSFTP:                 config.EnableSSHSFTP,
		EnableSSHLocalPortForwarding:  config.EnableSSHLocalPortForwarding,
		EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
		DisableSSHAuth:                config.DisableSSHAuth,
		DisableVNCAuth:                config.DisableVNCAuth,
		DNSRouteInterval:              config.DNSRouteInterval,

		DisableClientRoutes: config.DisableClientRoutes,
@@ -631,6 +629,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
		config.RosenpassEnabled,
		config.RosenpassPermissive,
		config.ServerSSHAllowed,
		config.ServerVNCAllowed,
		config.DisableClientRoutes,
		config.DisableServerRoutes,
		config.DisableDNS,
@@ -643,6 +642,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
		config.EnableSSHLocalPortForwarding,
		config.EnableSSHRemotePortForwarding,
		config.DisableSSHAuth,
		config.DisableVNCAuth,
	)
	return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}
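A note on the createEngineConfig hunk above: SSH stays allowed by default (util.ReturnBoolWithDefaultTrue), while the new VNC server flag is opt-in — a nil or false pointer disables it. The inline pointer guard is equivalent to a default-false helper like this (illustrative only; no such function appears in the diff):

func boolWithDefaultFalse(v *bool) bool {
	return v != nil && *v // nil (unset) and false both mean disabled
}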
@@ -21,7 +21,6 @@ import (
	"time"

	log "github.com/sirupsen/logrus"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"google.golang.org/protobuf/encoding/protojson"

	"github.com/netbirdio/netbird/client/anonymize"
@@ -62,7 +61,6 @@ allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.

Anonymization Process
@@ -236,7 +234,6 @@ type BundleGenerator struct {
	logPath       string
	tempDir       string
	cpuProfile    []byte
	capturePath   string
	refreshStatus func() // Optional callback to refresh status before bundle generation
	clientMetrics MetricsExporter

@@ -260,8 +257,7 @@ type GeneratorDependencies struct {
	LogPath       string
	TempDir       string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
	CPUProfile    []byte
	CapturePath   string
	RefreshStatus func()
	RefreshStatus func() // Optional callback to refresh status before bundle generation
	ClientMetrics MetricsExporter
}

@@ -281,7 +277,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
		logPath:       deps.LogPath,
		tempDir:       deps.TempDir,
		cpuProfile:    deps.CPUProfile,
		capturePath:   deps.CapturePath,
		refreshStatus: deps.RefreshStatus,
		clientMetrics: deps.ClientMetrics,

@@ -351,10 +346,6 @@ func (g *BundleGenerator) createArchive() error {
		log.Errorf("failed to add CPU profile to debug bundle: %v", err)
	}

	if err := g.addCaptureFile(); err != nil {
		log.Errorf("failed to add capture file to debug bundle: %v", err)
	}

	if err := g.addStackTrace(); err != nil {
		log.Errorf("failed to add stack trace to debug bundle: %v", err)
	}
@@ -584,9 +575,6 @@ func isSensitiveEnvVar(key string) bool {
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
	configContent.WriteString("NetBird Client Configuration:\n\n")

	if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
		configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
	}
	configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
	configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
	if g.internalConfig.NetworkMonitor != nil {
@@ -611,12 +599,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
	if g.internalConfig.EnableSSHRemotePortForwarding != nil {
		configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
	}
	if g.internalConfig.DisableSSHAuth != nil {
		configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
	}
	if g.internalConfig.SSHJWTCacheTTL != nil {
		configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
	}

	configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
	configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -643,7 +625,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
	}

	configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
	configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
}

func (g *BundleGenerator) addProf() (err error) {
@@ -688,29 +669,6 @@ func (g *BundleGenerator) addCPUProfile() error {
	return nil
}

func (g *BundleGenerator) addCaptureFile() error {
	if g.capturePath == "" {
		return nil
	}

	if g.anonymize {
		log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
		return nil
	}

	f, err := os.Open(g.capturePath)
	if err != nil {
		return fmt.Errorf("open capture file: %w", err)
	}
	defer f.Close()

	if err := g.addFileToZip(f, "capture.pcap"); err != nil {
		return fmt.Errorf("add capture file to zip: %w", err)
	}

	return nil
}

func (g *BundleGenerator) addStackTrace() error {
	buf := make([]byte, 5242880) // 5 MB buffer
	n := runtime.Stack(buf, true)
@@ -5,21 +5,16 @@ import (
	"bytes"
	"encoding/json"
	"net"
	"net/url"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/netbirdio/netbird/client/anonymize"
	"github.com/netbirdio/netbird/client/configs"
	"github.com/netbirdio/netbird/client/internal/profilemanager"
	"github.com/netbirdio/netbird/shared/management/domain"
	mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)

@@ -476,8 +471,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
			anonymize: false,
			input: map[string]any{
				jsonKeyServiceEnv: map[string]any{
					"HOME": "/root",
					"PATH": "/usr/bin",
					"HOME":         "/root",
					"PATH":         "/usr/bin",
					"NB_LOG_LEVEL": "debug",
				},
			},
@@ -494,9 +489,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
			anonymize: false,
			input: map[string]any{
				jsonKeyServiceEnv: map[string]any{
					"NB_SETUP_KEY": "abc123",
					"NB_API_TOKEN": "tok_xyz",
					"NB_LOG_LEVEL": "info",
					"NB_SETUP_KEY": "abc123",
					"NB_API_TOKEN": "tok_xyz",
					"NB_LOG_LEVEL": "info",
				},
			},
			check: func(t *testing.T, params map[string]any) {
@@ -771,127 +766,3 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
	assert.Contains(t, anonNftables, "chain input {")
	assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
}

// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
	excluded := map[string]string{
		"PrivateKey":        "sensitive: WireGuard private key",
		"PreSharedKey":      "sensitive: WireGuard pre-shared key",
		"SSHKey":            "sensitive: SSH private key",
		"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
	}

	mURL, _ := url.Parse("https://api.example.com:443")
	aURL, _ := url.Parse("https://admin.example.com:443")
	bTrue := true
	iVal := 42
	cfg := &profilemanager.Config{
		PrivateKey:                    "priv",
		PreSharedKey:                  "psk",
		ManagementURL:                 mURL,
		AdminURL:                      aURL,
		WgIface:                       "wt0",
		WgPort:                        51820,
		NetworkMonitor:                &bTrue,
		IFaceBlackList:                []string{"eth0"},
		DisableIPv6Discovery:          true,
		RosenpassEnabled:              true,
		RosenpassPermissive:           true,
		ServerSSHAllowed:              &bTrue,
		EnableSSHRoot:                 &bTrue,
		EnableSSHSFTP:                 &bTrue,
		EnableSSHLocalPortForwarding:  &bTrue,
		EnableSSHRemotePortForwarding: &bTrue,
		DisableSSHAuth:                &bTrue,
		SSHJWTCacheTTL:                &iVal,
		DisableClientRoutes:           true,
		DisableServerRoutes:           true,
		DisableDNS:                    true,
		DisableFirewall:               true,
		BlockLANAccess:                true,
		BlockInbound:                  true,
		DisableNotifications:          &bTrue,
		DNSLabels:                     domain.List{},
		SSHKey:                        "sshkey",
		NATExternalIPs:                []string{"1.2.3.4"},
		CustomDNSAddress:              "1.1.1.1:53",
		DisableAutoConnect:            true,
		DNSRouteInterval:              5 * time.Second,
		ClientCertPath:                "/tmp/cert",
		ClientCertKeyPath:             "/tmp/key",
		LazyConnectionEnabled:         true,
		MTU:                           1280,
	}

	for _, anonymize := range []bool{false, true} {
		t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
			g := &BundleGenerator{
				anonymizer:     newAnonymizerForTest(),
				internalConfig: cfg,
				anonymize:      anonymize,
			}

			var sb strings.Builder
			g.addCommonConfigFields(&sb)
			rendered := sb.String() + renderAddConfigSpecific(g)

			val := reflect.ValueOf(cfg).Elem()
			typ := val.Type()
			var missing []string
			for i := 0; i < typ.NumField(); i++ {
				name := typ.Field(i).Name
				if _, ok := excluded[name]; ok {
					continue
				}
				if !strings.Contains(rendered, name+":") {
					missing = append(missing, name)
				}
			}
			if len(missing) > 0 {
				t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
					"Either render the field in addCommonConfigFields/addConfig, "+
					"or add it to the excluded map with a justification.", missing)
			}
		})
	}
}

// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
	var sb strings.Builder
	if g.anonymize {
		if g.internalConfig.ManagementURL != nil {
			sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
		}
		if g.internalConfig.AdminURL != nil {
			sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
		}
		sb.WriteString("NATExternalIPs: x\n")
		if g.internalConfig.CustomDNSAddress != "" {
			sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
		}
	} else {
		if g.internalConfig.ManagementURL != nil {
			sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
		}
		if g.internalConfig.AdminURL != nil {
			sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
		}
		sb.WriteString("NATExternalIPs: x\n")
		if g.internalConfig.CustomDNSAddress != "" {
			sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
		}
	}
	return sb.String()
}

func newAnonymizerForTest() *anonymize.Anonymizer {
	return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}
@@ -3,12 +3,10 @@ package debug

import (
	"context"
	"errors"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
		t.Skip("Skipping upload test on docker ci")
	}
	testDir := t.TempDir()
	addr := reserveLoopbackPort(t)
	testURL := "http://" + addr
	testURL := "http://localhost:8080"
	t.Setenv("SERVER_URL", testURL)
	t.Setenv("SERVER_ADDRESS", addr)
	t.Setenv("STORE_DIR", testDir)
	srv := server.NewServer()
	go func() {
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
			t.Errorf("Failed to stop server: %v", err)
		}
	})
	waitForServer(t, addr)

	file := filepath.Join(t.TempDir(), "tmpfile")
	fileContent := []byte("test file content")
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
	require.NoError(t, err)
	require.Equal(t, fileContent, createdFileContent)
}

// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
	t.Helper()
	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := l.Addr().String()
	require.NoError(t, l.Close())
	return addr
}

func waitForServer(t *testing.T, addr string) {
	t.Helper()
	deadline := time.Now().Add(5 * time.Second)
	for time.Now().Before(deadline) {
		c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
		if err == nil {
			_ = c.Close()
			return
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Fatalf("server did not start listening on %s in time", addr)
}
@@ -13,7 +13,6 @@ import (

const (
	defaultResolvConfPath = "/etc/resolv.conf"
	nsswitchConfPath      = "/etc/nsswitch.conf"
)

type resolvConf struct {
@@ -1,10 +1,7 @@
package dns

import (
	"context"
	"fmt"
	"math"
	"net"
	"slices"
	"strconv"
	"strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
}

func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	c.dispatch(w, r, math.MaxInt)
}

// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
	if len(r.Question) == 0 {
		return
	}
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in

	// Try handlers in priority order
	for _, entry := range handlers {
		if entry.Priority > maxPriority {
			continue
		}
		if !c.isHandlerMatch(qname, entry) {
			continue
		}
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
		cw.response.Len(), meta, time.Since(startTime))
}

// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
	if len(r.Question) == 0 {
		return nil, fmt.Errorf("empty question")
	}

	base := &internalResponseWriter{}
	done := make(chan struct{})
	go func() {
		c.dispatch(base, r, maxPriority)
		close(done)
	}()

	select {
	case <-done:
	case <-ctx.Done():
		// Prefer a completed response if dispatch finished concurrently with cancellation.
		select {
		case <-done:
		default:
			return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
		}
	}

	if base.response == nil || base.response.Rcode == dns.RcodeRefused {
		return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
			strings.ToLower(r.Question[0].Name), maxPriority)
	}
	return base.response, nil
}

// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	for _, h := range c.handlers {
		if h.Pattern == "." && h.Priority <= maxPriority {
			return true
		}
	}
	return false
}

func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
	switch {
	case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
	}
}
}

// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
	response *dns.Msg
}

func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr       { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr      { return nil }

// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
	msg := new(dns.Msg)
	if err := msg.Unpack(p); err != nil {
		return 0, err
	}
	w.response = msg
	return len(p), nil
}

func (w *internalResponseWriter) Close() error      { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }

// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
	// no-op: in-process queries carry no TSIG state.
}

// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
	// no-op: in-process queries have no underlying connection to hand off.
}
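A hedged usage sketch for ResolveInternal, mirroring the mgmt-cache refresher scenario its comment describes: a handler registered at PriorityMgmtCache re-resolves only through lower-priority handlers, so it cannot loop back into itself. The domain and timeout are illustrative, and the nbdns alias matches the test file below rather than in-package code:

msg := new(dns.Msg)
msg.SetQuestion("api.example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Only handlers at priority <= PriorityUpstream may answer this query.
resp, err := chain.ResolveInternal(ctx, msg, nbdns.PriorityUpstream)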
@@ -1,15 +1,11 @@
package dns_test

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	nbdns "github.com/netbirdio/netbird/client/internal/dns"
	"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
		})
	}
}

// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
	name string
	ip   string
}

func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	resp := &dns.Msg{}
	resp.SetReply(r)
	resp.Answer = []dns.RR{&dns.A{
		Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
		A:   net.ParseIP(h.ip).To4(),
	}}
	_ = w.WriteMsg(resp)
}

func (h *answeringHandler) String() string { return h.name }

func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	high := &answeringHandler{name: "high", ip: "10.0.0.1"}
	low := &answeringHandler{name: "low", ip: "10.0.0.2"}

	chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
	chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)

	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
	assert.NoError(t, err)
	assert.NotNil(t, resp)
	assert.Equal(t, 1, len(resp.Answer))
	a, ok := resp.Answer[0].(*dns.A)
	assert.True(t, ok)
	assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}

func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	high := &answeringHandler{name: "high", ip: "10.0.0.1"}
	chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)

	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
	assert.Error(t, err, "no handler at or below maxPriority should error")
}

// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
	ip string
}

func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	resp := &dns.Msg{}
	resp.SetReply(r)
	resp.Answer = []dns.RR{&dns.A{
		Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
		A:   net.ParseIP(h.ip).To4(),
	}}
	packed, err := resp.Pack()
	if err != nil {
		return
	}
	_, _ = w.Write(packed)
}

func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)

	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
	assert.NoError(t, err)
	require.NotNil(t, resp)
	require.Len(t, resp.Answer, 1)
	a, ok := resp.Answer[0].(*dns.A)
	require.True(t, ok)
	assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}

func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
	assert.Error(t, err)
}

// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
	block chan struct{}
}

func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	<-h.block
	resp := &dns.Msg{}
	resp.SetReply(r)
	_ = w.WriteMsg(resp)
}

func (h *hangingHandler) String() string { return "hangingHandler" }

func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	h := &hangingHandler{block: make(chan struct{})}
	defer close(h.block)

	chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)

	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	start := time.Now()
	_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
	elapsed := time.Since(start)

	assert.Error(t, err)
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}

func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	h := &answeringHandler{name: "h", ip: "10.0.0.1"}

	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")

	chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")

	chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")

	chain.AddHandler(".", h, nbdns.PriorityDefault)
	assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")

	chain.RemoveHandler(".", nbdns.PriorityDefault)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))

	// Primary nsgroup case: root handler lands at PriorityUpstream.
	chain.AddHandler(".", h, nbdns.PriorityUpstream)
	assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
	chain.RemoveHandler(".", nbdns.PriorityUpstream)

	// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
	chain.AddHandler(".", h, nbdns.PriorityFallback)
	assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
	chain.RemoveHandler(".", nbdns.PriorityFallback)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}
@@ -46,12 +46,12 @@ type restoreHostManager interface {
}

func newHostManager(wgInterface string) (hostManager, error) {
	osManager, reason, err := getOSDNSManagerType()
	osManager, err := getOSDNSManagerType()
	if err != nil {
		return nil, fmt.Errorf("get os dns manager type: %w", err)
	}

	log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
	log.Infof("System DNS manager discovered: %s", osManager)
	mgr, err := newHostManagerFromType(wgInterface, osManager)
	// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
	if err != nil {
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
	}
}

func getOSDNSManagerType() (osManagerType, string, error) {
	resolved := isSystemdResolvedRunning()
	nss := isLibnssResolveUsed()
	stub := checkStub()

	// Prefer systemd-resolved whenever it owns libc resolution, regardless of
	// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
	// that go through nss-resolve, and in foreign mode they can loop back
	// through resolved as an upstream.
	if resolved && (nss || stub) {
		return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
	}

	mgr, reason, rejected, err := scanResolvConfHeader()
	if err != nil {
		return 0, "", err
	}
	if reason != "" {
		return mgr, reason, nil
	}

	fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
	if len(rejected) > 0 {
		fallback += "; rejected: " + strings.Join(rejected, ", ")
	}
	return fileManager, fallback, nil
}

// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
func getOSDNSManagerType() (osManagerType, error) {
	file, err := os.Open(defaultResolvConfPath)
	if err != nil {
		return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
		return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
	}
	defer func() {
		if cerr := file.Close(); cerr != nil {
			log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
		if err := file.Close(); err != nil {
			log.Errorf("close file %s: %s", defaultResolvConfPath, err)
		}
	}()

	var rejected []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		text := scanner.Text()
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
			continue
		}
		if text[0] != '#' {
			break
			return fileManager, nil
		}
		if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
			return mgr, reason, nil, nil
		} else if rej != "" {
			rejected = append(rejected, rej)
		if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
			return netbirdManager, nil
		}
		if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
			return networkManager, nil
		}
		if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
			if checkStub() {
				return systemdManager, nil
			} else {
				return fileManager, nil
			}
		}
		if strings.Contains(text, "resolvconf") {
			if isSystemdResolveConfMode() {
				return systemdManager, nil
			}

			return resolvConfManager, nil
		}
	}
	if err := scanner.Err(); err != nil && err != io.EOF {
		return 0, "", nil, fmt.Errorf("scan: %w", err)
		return 0, fmt.Errorf("scan: %w", err)
	}
	return 0, "", rejected, nil

	return fileManager, nil
}

// matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
	if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
		return netbirdManager, "netbird-managed resolv.conf header detected", ""
	}
	if strings.Contains(text, "NetworkManager") {
		if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
			return networkManager, "NetworkManager header + supported version on dbus", ""
		}
		return 0, "", "NetworkManager header (no dbus or unsupported version)"
	}
	if strings.Contains(text, "resolvconf") {
		if isSystemdResolveConfMode() {
			return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
		}
		return resolvConfManager, "resolvconf header detected", ""
	}
	return 0, "", ""
}

// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
func checkStub() bool {
	rConf, err := parseDefaultResolvConf()
	if err != nil {
		log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
		log.Warnf("failed to parse resolv conf: %s", err)
		return true
	}
@@ -178,36 +139,3 @@ func checkStub() bool {

	return false
}

// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
	bs, err := os.ReadFile(nsswitchConfPath)
	if err != nil {
		log.Debugf("read %s: %v", nsswitchConfPath, err)
		return false
	}
	return parseNsswitchResolveAhead(bs)
}

func parseNsswitchResolveAhead(data []byte) bool {
	for _, line := range strings.Split(string(data), "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) < 2 || fields[0] != "hosts:" {
			continue
		}
		for _, module := range fields[1:] {
			switch module {
			case "dns":
				return false
			case "resolve":
				return true
			}
		}
	}
	return false
}
@@ -1,76 +0,0 @@
//go:build (linux && !android) || freebsd

package dns

import "testing"

func TestParseNsswitchResolveAhead(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want bool
	}{
		{
			name: "resolve before dns with action token",
			in:   "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
			want: true,
		},
		{
			name: "dns before resolve",
			in:   "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
			want: false,
		},
		{
			name: "debian default with only dns",
			in:   "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
			want: false,
		},
		{
			name: "neither resolve nor dns",
			in:   "hosts: files myhostname\n",
			want: false,
		},
		{
			name: "no hosts line",
			in:   "passwd: files systemd\ngroup: files systemd\n",
			want: false,
		},
		{
			name: "empty",
			in:   "",
			want: false,
		},
		{
			name: "comments and blank lines ignored",
			in:   "# comment\n\n# another\nhosts: resolve dns\n",
			want: true,
		},
		{
			name: "trailing inline comment",
			in:   "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
			want: true,
		},
		{
			name: "hosts token must be the first field",
			in:   " hosts: resolve dns\n",
			want: true,
		},
		{
			name: "other db line mentioning resolve is ignored",
			in:   "networks: resolve\nhosts: dns\n",
			want: false,
		},
		{
			name: "only resolve, no dns",
			in:   "hosts: files resolve\n",
			want: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
				t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
			}
		})
	}
}
@@ -2,83 +2,40 @@ package mgmt

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"net/url"
	"os"
	"slices"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"
	"golang.org/x/sync/singleflight"

	dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
	"github.com/netbirdio/netbird/client/internal/dns/resutil"
	"github.com/netbirdio/netbird/shared/management/domain"
)

const (
	dnsTimeout     = 5 * time.Second
	defaultTTL     = 300 * time.Second
	refreshBackoff = 30 * time.Second
const dnsTimeout = 5 * time.Second

	// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
	envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)

// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
	ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
	HasRootHandlerAtOrBelow(maxPriority int) bool
}

// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
	records           []dns.RR
	cachedAt          time.Time
	lastFailedRefresh time.Time
	consecFailures    int
}

// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
	records map[dns.Question]*cachedRecord
	records map[dns.Question][]dns.RR
	mgmtDomain    *domain.Domain
	serverDomains *dnsconfig.ServerDomains
	mutex         sync.RWMutex
}

	chain            ChainResolver
	chainMaxPriority int
	refreshGroup     singleflight.Group

	// refreshing tracks questions whose refresh is running via the OS
	// fallback path. A ServeDNS hit for a question in this map indicates
	// the OS resolver routed the recursive query back to us (loop). Only
	// the OS path arms this so chain-path refreshes don't produce false
	// positives. The atomic bool is CAS-flipped once per refresh to
	// throttle the warning log.
	refreshing map[dns.Question]*atomic.Bool

	cacheTTL time.Duration
type ipsResponse struct {
	ips []netip.Addr
	err error
}

// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
	return &Resolver{
		records:    make(map[dns.Question]*cachedRecord),
		refreshing: make(map[dns.Question]*atomic.Bool),
		cacheTTL:   resolveCacheTTL(),
		records: make(map[dns.Question][]dns.RR),
	}
}

@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
	return "MgmtCacheResolver"
}

// SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
	m.mutex.Lock()
	m.chain = chain
	m.chainMaxPriority = maxPriority
	m.mutex.Unlock()
}

// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if len(r.Question) == 0 {
		m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	}

	m.mutex.RLock()
	cached, found := m.records[question]
	inflight := m.refreshing[question]
	var shouldRefresh bool
	if found {
		stale := time.Since(cached.cachedAt) > m.cacheTTL
		inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
		shouldRefresh = stale && !inBackoff
	}
	records, found := m.records[question]
	m.mutex.RUnlock()

	if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
		return
	}

	if inflight != nil && inflight.CompareAndSwap(false, true) {
		log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
			question.Name)
	}

	// Skip scheduling a refresh goroutine if one is already inflight for
	// this question; singleflight would dedup anyway but skipping avoids
	// a parked goroutine per stale hit under bursty load.
	if shouldRefresh && inflight == nil {
		m.scheduleRefresh(question, cached)
	}

	resp := &dns.Msg{}
	resp.SetReply(r)
	resp.Authoritative = false
	resp.RecursionAvailable = true
	resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))

	resp.Answer = append(resp.Answer, records...)

	log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
// AddDomain manually adds a domain to cache by resolving it.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||
|
||||
if errA != nil && errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if err := errors.Join(errA, errAAAA); err != nil {
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
aRecords = append(aRecords, rr)
|
||||
} else if ip.Is6() {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
aaaaRecords = append(aaaaRecords, rr)
|
||||
}
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||
if len(aRecords) > 0 {
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aQuestion] = aRecords
|
||||
}
|
||||
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
if len(aaaaRecords) > 0 {
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aaaaQuestion] = aaaaRecords
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||
// untouched on error. Caller holds m.mutex.
|
||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
switch {
|
||||
case len(records) > 0:
|
||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||
case err == nil:
|
||||
delete(m.records, q)
|
||||
}
|
||||
}
|
||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||
resultChan := make(chan *ipsResponse, 1)
|
||||
|
||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||
return nil, m.refreshQuestion(question, expected)
|
||||
})
|
||||
}
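
// Aside: a minimal sketch (not part of this change) of how singleflight.DoChan
// collapses concurrent callers onto one execution per key. Only the
// x/sync/singleflight API is real here; slowLookup is an illustrative name.
//
//	var g singleflight.Group
//	ch1 := g.DoChan("example.com.|A", slowLookup) // runs slowLookup once
//	ch2 := g.DoChan("example.com.|A", slowLookup) // joins the same flight
//	r1, r2 := <-ch1, <-ch2
//	// r1.Val == r2.Val; r2.Shared reports that the result was shared.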

// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
	ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
	defer cancel()

	d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
	if err != nil {
		m.markRefreshFailed(question, expected)
		return fmt.Errorf("parse domain: %w", err)
	}

	records, err := m.lookupRecords(ctx, d, question)
	if err != nil {
		fails := m.markRefreshFailed(question, expected)
		logf := log.Warnf
		if fails == 0 || fails > 1 {
			logf = log.Debugf
	go func() {
		ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
		resultChan <- &ipsResponse{
			err: err,
			ips: ips,
		}
		logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
			d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
		return err
	}()

	var resp *ipsResponse

	select {
	case <-time.After(dnsTimeout + time.Millisecond*500):
		log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
		return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
	case <-ctx.Done():
		return nil, ctx.Err()
	case resp = <-resultChan:
	}

	// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
	if len(records) == 0 {
		m.mutex.Lock()
		if m.records[question] == expected {
			delete(m.records, question)
			m.mutex.Unlock()
			log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
				d.SafeString(), dns.TypeToString[question.Qtype])
			return nil
		}
		m.mutex.Unlock()
		log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
			d.SafeString(), dns.TypeToString[question.Qtype])
		return nil
	if resp.err != nil {
		return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
	}

	now := time.Now()
	m.mutex.Lock()
	if m.records[question] != expected {
		m.mutex.Unlock()
		log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
			d.SafeString(), dns.TypeToString[question.Qtype])
		return nil
	}
	m.records[question] = &cachedRecord{records: records, cachedAt: now}
	m.mutex.Unlock()

	log.Infof("refreshed mgmt cache domain=%s type=%s",
		d.SafeString(), dns.TypeToString[question.Qtype])
	return nil
}

func (m *Resolver) markRefreshing(question dns.Question) {
	m.mutex.Lock()
	m.refreshing[question] = &atomic.Bool{}
	m.mutex.Unlock()
}

func (m *Resolver) clearRefreshing(question dns.Question) {
	m.mutex.Lock()
	delete(m.refreshing, question)
	m.mutex.Unlock()
}

// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	c, ok := m.records[question]
	if !ok || c != expected {
		return 0
	}
	c.lastFailedRefresh = time.Now()
	c.consecFailures++
	return c.consecFailures
}

// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
	m.mutex.RLock()
	chain := m.chain
	maxPriority := m.chainMaxPriority
	m.mutex.RUnlock()

	if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
		aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
		aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
		return
	}

	// TODO: drop once every supported OS registers a fallback resolver. Safe
	// today: no root handler at priority ≤ PriorityUpstream means NetBird is
	// not the system resolver, so net.DefaultResolver will not loop back.
	aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
	aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
	return
}

// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
	m.mutex.RLock()
	chain := m.chain
	maxPriority := m.chainMaxPriority
	m.mutex.RUnlock()

	if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
		return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
	}

	// TODO: drop once every supported OS registers a fallback resolver.
	m.markRefreshing(q)
	defer m.clearRefreshing(q)

	return m.osLookup(ctx, d, q.Name, q.Qtype)
}

// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
	msg := &dns.Msg{}
	msg.SetQuestion(dnsName, qtype)
	msg.RecursionDesired = true

	resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
	if err != nil {
		return nil, fmt.Errorf("chain resolve: %w", err)
	}
	if resp == nil {
		return nil, fmt.Errorf("chain resolve returned nil response")
	}
	if resp.Rcode != dns.RcodeSuccess {
		return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
	}

	ttl := uint32(m.cacheTTL.Seconds())
	owners := cnameOwners(dnsName, resp.Answer)
	var filtered []dns.RR
	for _, rr := range resp.Answer {
		h := rr.Header()
		if h.Class != dns.ClassINET || h.Rrtype != qtype {
			continue
		}
		if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
			continue
		}
		if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
			filtered = append(filtered, cp)
		}
	}
	return filtered, nil
}

// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
	network := resutil.NetworkForQtype(qtype)
	if network == "" {
		return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
	}

	log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
	defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])

	result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
	if result.Rcode == dns.RcodeSuccess {
		return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
	}

	if result.Err != nil {
		return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
	}
	return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}

// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
	remaining := m.cacheTTL - time.Since(cachedAt)
	if remaining <= 0 {
		return 0
	}
	return uint32((remaining + time.Second - 1) / time.Second)
	return resp.ips, nil
}
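
// Aside: the ceil-division above rounds partial seconds up. A worked
// illustration (not extra behavior): with 1.8s remaining,
// (1.8s + 1s - 1ns) / 1s truncates to 2, so a downstream cache expires
// no later than this one does.
//
//	remaining := 1800 * time.Millisecond
//	ttl := uint32((remaining + time.Second - 1) / time.Second) // ttl == 2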

// PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
	qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
	delete(m.records, qA)
	delete(m.records, qAAAA)
	delete(m.refreshing, qA)
	delete(m.refreshing, qAAAA)
	aQuestion := dns.Question{
		Name:   dnsName,
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}
	delete(m.records, aQuestion)

	aaaaQuestion := dns.Question{
		Name:   dnsName,
		Qtype:  dns.TypeAAAA,
		Qclass: dns.ClassINET,
	}
	delete(m.records, aaaaQuestion)

	log.Debugf("removed domain=%s from cache", d.SafeString())
	return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve

	return domains
}

// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
	switch r := rr.(type) {
	case *dns.A:
		cp := *r
		cp.Hdr.Name = owner
		cp.Hdr.Ttl = ttl
		cp.A = slices.Clone(r.A)
		return &cp
	case *dns.AAAA:
		cp := *r
		cp.Hdr.Name = owner
		cp.Hdr.Ttl = ttl
		cp.AAAA = slices.Clone(r.AAAA)
		return &cp
	}
	return nil
}

// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
	out := make([]dns.RR, 0, len(records))
	for _, rr := range records {
		if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
			out = append(out, cp)
		}
	}
	return out
}
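
// Aside: the deep copy matters because dns.A/dns.AAAA store the address as a
// shared []byte (net.IP); without slices.Clone a caller mutating a served
// answer would corrupt the cached entry. A minimal sketch of the aliasing
// hazard (illustrative values only):
//
//	orig := &dns.A{A: net.IP{10, 0, 0, 1}}
//	shallow := *orig              // header copied, A still aliases orig.A
//	shallow.A[3] = 9              // also flips orig.A to 10.0.0.9
//	safe := *orig
//	safe.A = slices.Clone(orig.A) // now mutations stay local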

// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
	owners := map[string]bool{dnsName: true}
	for {
		added := false
		for _, rr := range answer {
			cname, ok := rr.(*dns.CNAME)
			if !ok {
				continue
			}
			name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
			if !owners[name] {
				continue
			}
			target := strings.ToLower(dns.Fqdn(cname.Target))
			if !owners[target] {
				owners[target] = true
				added = true
			}
		}
		if !added {
			return owners
		}
	}
}
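
// Aside: a hedged usage sketch of the fixed-point walk above with an
// out-of-order chain (example data only). The second pass picks up the
// b -> c link once b is known to be reachable from the query name:
//
//	answer := []dns.RR{
//		&dns.CNAME{Hdr: dns.RR_Header{Name: "b.example."}, Target: "c.example."},
//		&dns.CNAME{Hdr: dns.RR_Header{Name: "a.example."}, Target: "b.example."},
//	}
//	owners := cnameOwners("a.example.", answer)
//	// owners == {"a.example.": true, "b.example.": true, "c.example.": true}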

// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
	if v := os.Getenv(envMgmtCacheTTL); v != "" {
		if d, err := time.ParseDuration(v); err == nil && d > 0 {
			return d
		}
	}
	return defaultTTL
}

@@ -1,408 +0,0 @@
package mgmt

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/netbirdio/netbird/client/internal/dns/test"
	"github.com/netbirdio/netbird/shared/management/domain"
)

type fakeChain struct {
	mu      sync.Mutex
	calls   map[string]int
	answers map[string][]dns.RR
	err     error
	hasRoot bool
	onLookup func()
}

func newFakeChain() *fakeChain {
	return &fakeChain{
		calls:   map[string]int{},
		answers: map[string][]dns.RR{},
		hasRoot: true,
	}
}

func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
	f.mu.Lock()
	defer f.mu.Unlock()
	return f.hasRoot
}

func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
	f.mu.Lock()
	q := msg.Question[0]
	key := q.Name + "|" + dns.TypeToString[q.Qtype]
	f.calls[key]++
	answers := f.answers[key]
	err := f.err
	onLookup := f.onLookup
	f.mu.Unlock()

	if onLookup != nil {
		onLookup()
	}
	if err != nil {
		return nil, err
	}
	resp := &dns.Msg{}
	resp.SetReply(msg)
	resp.Answer = answers
	return resp, nil
}

func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	key := name + "|" + dns.TypeToString[qtype]
	hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
	switch qtype {
	case dns.TypeA:
		f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
	case dns.TypeAAAA:
		f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
	}
}

func (f *fakeChain) callCount(name string, qtype uint16) int {
	f.mu.Lock()
	defer f.mu.Unlock()
	return f.calls[name+"|"+dns.TypeToString[qtype]]
}

// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
	t.Helper()
	deadline := time.Now().Add(d)
	for time.Now().Before(deadline) {
		if fn() {
			return
		}
		time.Sleep(5 * time.Millisecond)
	}
	t.Fatalf("condition not met within %s", d)
}

func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
	t.Helper()
	msg := new(dns.Msg)
	msg.SetQuestion(name, dns.TypeA)
	w := &test.MockResponseWriter{}
	r.ServeDNS(w, msg)
	return w.GetLastResponse()
}

func firstA(t *testing.T, resp *dns.Msg) string {
	t.Helper()
	require.NotNil(t, resp)
	require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
	a, ok := resp.Answer[0].(*dns.A)
	require.True(t, ok, "expected A record")
	return a.A.String()
}

func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
	// Same cached entry age, different cacheTTL values: the shorter TTL must
	// trigger a background refresh, the longer one must not. Proves that the
	// per-Resolver cacheTTL field actually drives the stale decision.
	cachedAt := time.Now().Add(-100 * time.Millisecond)

	newRec := func() *cachedRecord {
		return &cachedRecord{
			records: []dns.RR{&dns.A{
				Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
				A:   net.ParseIP("10.0.0.1").To4(),
			}},
			cachedAt: cachedAt,
		}
	}
	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}

	t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
		r := NewResolver()
		r.cacheTTL = 10 * time.Millisecond
		chain := newFakeChain()
		chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
		r.SetChainResolver(chain, 50)
		r.records[q] = newRec()

		resp := queryA(t, r, q.Name)
		assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")

		waitFor(t, time.Second, func() bool {
			return chain.callCount(q.Name, dns.TypeA) >= 1
		})
	})

	t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
		r := NewResolver()
		r.cacheTTL = time.Hour
		chain := newFakeChain()
		chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
		r.SetChainResolver(chain, 50)
		r.records[q] = newRec()

		resp := queryA(t, r, q.Name)
		assert.Equal(t, "10.0.0.1", firstA(t, resp))

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
	})
}

func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
	r := NewResolver()
	chain := newFakeChain()
	chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
	r.SetChainResolver(chain, 50)

	r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
		records: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   net.ParseIP("10.0.0.1").To4(),
		}},
		cachedAt: time.Now(), // fresh
	}

	resp := queryA(t, r, "mgmt.example.com.")
	assert.Equal(t, "10.0.0.1", firstA(t, resp))

	time.Sleep(20 * time.Millisecond)
	assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}

func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
	r := NewResolver()
	chain := newFakeChain()
	chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
	r.SetChainResolver(chain, 50)

	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
	r.records[q] = &cachedRecord{
		records: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   net.ParseIP("10.0.0.1").To4(),
		}},
		cachedAt: time.Now().Add(-2 * defaultTTL), // stale
	}

	// First query: serves stale immediately.
	resp := queryA(t, r, "mgmt.example.com.")
	assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")

	waitFor(t, time.Second, func() bool {
		return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
	})

	// Next query should now return the refreshed IP.
	waitFor(t, time.Second, func() bool {
		resp := queryA(t, r, "mgmt.example.com.")
		return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
	})
}

func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
	r := NewResolver()
	chain := newFakeChain()
	chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")

	var inflight atomic.Int32
	var maxInflight atomic.Int32
	chain.onLookup = func() {
		cur := inflight.Add(1)
		defer inflight.Add(-1)
		for {
			prev := maxInflight.Load()
			if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
				break
			}
		}
		time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
	}

	r.SetChainResolver(chain, 50)

	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
	r.records[q] = &cachedRecord{
		records: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   net.ParseIP("10.0.0.1").To4(),
		}},
		cachedAt: time.Now().Add(-2 * defaultTTL),
	}

	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			queryA(t, r, "mgmt.example.com.")
		}()
	}
	wg.Wait()

	waitFor(t, 2*time.Second, func() bool {
		return inflight.Load() == 0
	})

	calls := chain.callCount("mgmt.example.com.", dns.TypeA)
	assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
	assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}

func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
	r := NewResolver()
	chain := newFakeChain()
	chain.err = errors.New("boom")
	r.SetChainResolver(chain, 50)

	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
	r.records[q] = &cachedRecord{
		records: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   net.ParseIP("10.0.0.1").To4(),
		}},
		cachedAt: time.Now().Add(-2 * defaultTTL),
	}

	// First stale hit triggers a refresh attempt that fails.
	resp := queryA(t, r, "mgmt.example.com.")
	assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")

	waitFor(t, time.Second, func() bool {
		return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
	})
	waitFor(t, time.Second, func() bool {
		r.mutex.RLock()
		defer r.mutex.RUnlock()
		c, ok := r.records[q]
		return ok && !c.lastFailedRefresh.IsZero()
	})

	// Subsequent stale hits within backoff window should not schedule more refreshes.
	for i := 0; i < 10; i++ {
		queryA(t, r, "mgmt.example.com.")
	}
	time.Sleep(50 * time.Millisecond)
	assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}

func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
	r := NewResolver()
	chain := newFakeChain()
	chain.hasRoot = false
	chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
	r.SetChainResolver(chain, 50)

	// With hasRoot=false the chain must not be consulted. Use a short
	// deadline so the OS fallback returns quickly without waiting on a
	// real network call in CI.
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")

	assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
		"chain must not be used when no root handler is registered at the bound priority")
}

func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
	// ServeDNS being invoked for a question while a refresh for that question
	// is inflight indicates a resolver loop (OS resolver sent the recursive
	// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
	r := NewResolver()

	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
	r.records[q] = &cachedRecord{
		records: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   net.ParseIP("10.0.0.1").To4(),
		}},
		cachedAt: time.Now(),
	}

	// Simulate an inflight refresh.
	r.markRefreshing(q)
	defer r.clearRefreshing(q)

	resp := queryA(t, r, "mgmt.example.com.")
	assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")

	r.mutex.RLock()
	inflight := r.refreshing[q]
	r.mutex.RUnlock()
	require.NotNil(t, inflight)
	assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}

func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
	r := NewResolver()

	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
	r.records[q] = &cachedRecord{
		records: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   net.ParseIP("10.0.0.1").To4(),
		}},
		cachedAt: time.Now(),
	}

	r.markRefreshing(q)
	defer r.clearRefreshing(q)

	// Multiple ServeDNS calls during the same refresh must not re-set the flag
	// (CompareAndSwap from false -> true returns true only on the first call).
	for range 5 {
		queryA(t, r, "mgmt.example.com.")
	}

	r.mutex.RLock()
	inflight := r.refreshing[q]
	r.mutex.RUnlock()
	assert.True(t, inflight.Load())
}

func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
	r := NewResolver()

	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
	r.records[q] = &cachedRecord{
		records: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   net.ParseIP("10.0.0.1").To4(),
		}},
		cachedAt: time.Now(),
	}

	queryA(t, r, "mgmt.example.com.")

	r.mutex.RLock()
	_, ok := r.refreshing[q]
	r.mutex.RUnlock()
	assert.False(t, ok, "no refresh inflight means no loop tracking")
}

func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
	r := NewResolver()
	chain := newFakeChain()
	chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
	chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
	r.SetChainResolver(chain, 50)

	require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))

	resp := queryA(t, r, "mgmt.example.com.")
	assert.Equal(t, "10.0.0.2", firstA(t, resp))
	assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
	assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}
@@ -6,7 +6,6 @@ import (
	"net/url"
	"strings"
	"testing"
	"time"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
	assert.False(t, resolver.MatchSubdomains())
}

func TestResolveCacheTTL(t *testing.T) {
	tests := []struct {
		name  string
		value string
		want  time.Duration
	}{
		{"unset falls back to default", "", defaultTTL},
		{"valid duration", "45s", 45 * time.Second},
		{"valid minutes", "2m", 2 * time.Minute},
		{"malformed falls back to default", "not-a-duration", defaultTTL},
		{"zero falls back to default", "0s", defaultTTL},
		{"negative falls back to default", "-5s", defaultTTL},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv(envMgmtCacheTTL, tc.value)
			got := resolveCacheTTL()
			assert.Equal(t, tc.want, got, "parsed TTL should match")
		})
	}
}

func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
	t.Setenv(envMgmtCacheTTL, "7s")
	r := NewResolver()
	assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}

func TestResolver_ResponseTTL(t *testing.T) {
	now := time.Now()
	tests := []struct {
		name     string
		cacheTTL time.Duration
		cachedAt time.Time
		wantMin  uint32
		wantMax  uint32
	}{
		{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
		{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
		{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
		{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			r := &Resolver{cacheTTL: tc.cacheTTL}
			got := r.responseTTL(tc.cachedAt)
			assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
			assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
		})
	}
}

func TestResolver_ExtractDomainFromURL(t *testing.T) {
	tests := []struct {
		name string

@@ -212,7 +212,6 @@ func newDefaultServer(
	ctx, stop := context.WithCancel(ctx)

	mgmtCacheResolver := mgmt.NewResolver()
	mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)

	defaultServer := &DefaultServer{
		ctx: ctx,

@@ -26,9 +26,7 @@ import (

	nberrors "github.com/netbirdio/netbird/client/errors"
	"github.com/netbirdio/netbird/client/firewall"
	"github.com/netbirdio/netbird/client/firewall/firewalld"
	firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
	"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
	"github.com/netbirdio/netbird/client/iface"
	"github.com/netbirdio/netbird/client/iface/device"
	nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
@@ -69,7 +67,6 @@ import (
	signal "github.com/netbirdio/netbird/shared/signal/client"
	sProto "github.com/netbirdio/netbird/shared/signal/proto"
	"github.com/netbirdio/netbird/util"
	"github.com/netbirdio/netbird/util/capture"
)

// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -120,11 +117,13 @@ type EngineConfig struct {
	RosenpassPermissive bool

	ServerSSHAllowed bool
	ServerVNCAllowed bool
	EnableSSHRoot *bool
	EnableSSHSFTP *bool
	EnableSSHLocalPortForwarding *bool
	EnableSSHRemotePortForwarding *bool
	DisableSSHAuth *bool
	DisableVNCAuth *bool

	DNSRouteInterval time.Duration

@@ -201,6 +200,7 @@ type Engine struct {
	networkMonitor *networkmonitor.NetworkMonitor

	sshServer sshServer
	vncSrv vncServer

	statusRecorder *peer.Status

@@ -220,8 +220,6 @@ type Engine struct {
	portForwardManager *portforward.Manager
	srWatcher *guard.SRWatcher

	afpacketCapture *capture.AFPacketCapture

	// Sync response persistence (protected by syncRespMux)
	syncRespMux sync.RWMutex
	persistSyncResponse bool
@@ -316,6 +314,10 @@ func (e *Engine) Stop() error {
		log.Warnf("failed to stop SSH server: %v", err)
	}

	if err := e.stopVNCServer(); err != nil {
		log.Warnf("failed to stop VNC server: %v", err)
	}

	e.cleanupSSHConfig()

	if e.ingressGatewayMgr != nil {
@@ -575,7 +577,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
	e.connMgr.Start(e.ctx)

	e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
	e.srWatcher.Start(peer.IsForceRelayed())
	e.srWatcher.Start()

	e.receiveSignalEvents()
	e.receiveManagementEvents()
@@ -609,8 +611,6 @@ func (e *Engine) createFirewall() error {
		return nil
	}

	firewalld.SetParentContext(e.ctx)

	var err error
	e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
	if err != nil {
@@ -948,12 +948,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
		return fmt.Errorf("update relay token: %w", err)
	}

	urls := update.Urls
	if override, ok := peer.OverrideRelayURLs(); ok {
		log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
		urls = override
	}
	e.relayManager.UpdateServerURLs(urls)
	e.relayManager.UpdateServerURLs(update.Urls)

	// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
	// We can ignore all errors because the guard will manage the reconnection retries.
@@ -1010,6 +1005,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
		e.config.RosenpassEnabled,
		e.config.RosenpassPermissive,
		&e.config.ServerSSHAllowed,
		&e.config.ServerVNCAllowed,
		e.config.DisableClientRoutes,
		e.config.DisableServerRoutes,
		e.config.DisableDNS,
@@ -1022,6 +1018,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
		e.config.EnableSSHLocalPortForwarding,
		e.config.EnableSSHRemotePortForwarding,
		e.config.DisableSSHAuth,
		e.config.DisableVNCAuth,
	)

	if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -1049,6 +1046,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
		}
	}

	if err := e.updateVNC(conf.GetSshConfig()); err != nil {
		log.Warnf("failed handling VNC server setup: %v", err)
	}

	state := e.statusRecorder.GetLocalPeerState()
	state.IP = e.wgInterface.Address().String()
	state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1151,6 +1152,7 @@ func (e *Engine) receiveManagementEvents() {
		e.config.RosenpassEnabled,
		e.config.RosenpassPermissive,
		&e.config.ServerSSHAllowed,
		&e.config.ServerVNCAllowed,
		e.config.DisableClientRoutes,
		e.config.DisableServerRoutes,
		e.config.DisableDNS,
@@ -1163,6 +1165,7 @@ func (e *Engine) receiveManagementEvents() {
		e.config.EnableSSHLocalPortForwarding,
		e.config.EnableSSHRemotePortForwarding,
		e.config.DisableSSHAuth,
		e.config.DisableVNCAuth,
	)

	err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1337,6 +1340,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
	}

	e.updateSSHServerAuth(networkMap.GetSshAuth())

	// VNC auth: use dedicated VNCAuth if present.
	if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
		e.updateVNCServerAuth(vncAuth)
	}
	}

	// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1707,11 +1715,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
	}

func (e *Engine) close() {
	if e.afpacketCapture != nil {
		e.afpacketCapture.Stop()
		e.afpacketCapture = nil
	}

	log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)

	if e.wgInterface != nil {
@@ -1751,6 +1754,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
		e.config.RosenpassEnabled,
		e.config.RosenpassPermissive,
		&e.config.ServerSSHAllowed,
		&e.config.ServerVNCAllowed,
		e.config.DisableClientRoutes,
		e.config.DisableServerRoutes,
		e.config.DisableDNS,
@@ -1763,6 +1767,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
		e.config.EnableSSHLocalPortForwarding,
		e.config.EnableSSHRemotePortForwarding,
		e.config.DisableSSHAuth,
		e.config.DisableVNCAuth,
	)

	netMap, err := e.mgmClient.GetNetworkMap(info)
@@ -2177,62 +2182,6 @@ func (e *Engine) Address() (netip.Addr, error) {
	return e.wgInterface.Address().IP, nil
}

// SetCapture sets or clears packet capture on the WireGuard device.
// On userspace WireGuard, it taps the FilteredDevice directly.
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
// Pass nil to disable capture.
func (e *Engine) SetCapture(pc device.PacketCapture) error {
	e.syncMsgMux.Lock()
	defer e.syncMsgMux.Unlock()

	intf := e.wgInterface
	if intf == nil {
		return errors.New("wireguard interface not initialized")
	}

	if e.afpacketCapture != nil {
		e.afpacketCapture.Stop()
		e.afpacketCapture = nil
	}

	dev := intf.GetDevice()
	if dev != nil {
		dev.SetCapture(pc)
		e.setForwarderCapture(pc)
		return nil
	}

	// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
	if pc == nil {
		return nil
	}
	sess, ok := pc.(*capture.Session)
	if !ok {
		return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
	}

	afc := capture.NewAFPacketCapture(intf.Name(), sess)
	if err := afc.Start(); err != nil {
		return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
	}
	e.afpacketCapture = afc
	return nil
}

// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
	if e.firewall == nil {
		return
	}
	type forwarderCapturer interface {
		SetPacketCapture(pc forwarder.PacketCapture)
	}
	if fc, ok := e.firewall.(forwarderCapturer); ok {
		fc.SetPacketCapture(pc)
	}
}

func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
	if e.firewall == nil {
		log.Warn("firewall is disabled, not updating forwarding rules")
@@ -2454,8 +2403,6 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
		}
	}

	relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())

	offerAnswer := peer.OfferAnswer{
		IceCredentials: peer.IceCredentials{
			UFrag: remoteCred.UFrag,
@@ -2466,23 +2413,7 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
		RosenpassPubKey: rosenpassPubKey,
		RosenpassAddr:   rosenpassAddr,
		RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
		RelaySrvIP:      relayIP,
		SessionID:       sessionID,
	}
	return &offerAnswer, nil
}

// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
// netip.Addr. Returns the zero value for empty input and logs a warning
// for malformed payloads.
func decodeRelayIP(b []byte) netip.Addr {
	if len(b) == 0 {
		return netip.Addr{}
	}
	ip, ok := netip.AddrFromSlice(b)
	if !ok {
		log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
		return netip.Addr{}
	}
	return ip.Unmap()
}
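
// Aside: netip.AddrFromSlice accepts only 4- or 16-byte slices, so malformed
// payloads fail the ok check above. Hedged illustrative inputs (my reading of
// the stdlib behavior, not from this change):
//
//	decodeRelayIP(nil)                    // zero netip.Addr
//	decodeRelayIP([]byte{203, 0, 113, 7}) // 203.0.113.7
//	decodeRelayIP(netip.MustParseAddr("::ffff:10.0.0.1").AsSlice()) // Unmap -> 10.0.0.1
//	decodeRelayIP([]byte{1, 2, 3})        // warning + zero value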

@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
	if err != nil {
		return nil, "", err
	}
	mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
	mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
	if err != nil {
		return nil, "", err
	}

client/internal/engine_vnc.go (new file, 309 lines)
@@ -0,0 +1,309 @@
package internal

import (
	"context"
	"errors"
	"fmt"
	"net/netip"
	"os"
	"path/filepath"

	log "github.com/sirupsen/logrus"

	firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
	sshauth "github.com/netbirdio/netbird/client/ssh/auth"
	vncserver "github.com/netbirdio/netbird/client/vnc/server"
	mgmProto "github.com/netbirdio/netbird/shared/management/proto"
	sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)

const envVNCForceRecording = "NB_VNC_FORCE_RECORDING"

const (
	vncExternalPort uint16 = 5900
	vncInternalPort uint16 = 25900
)

type vncServer interface {
	Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
	Stop() error
}
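
// Aside: because the engine only depends on this small interface, tests can
// substitute a no-op double. A hedged sketch (fakeVNCServer is not part of
// this diff):
//
//	type fakeVNCServer struct{ started bool }
//
//	func (f *fakeVNCServer) Start(context.Context, netip.AddrPort, netip.Prefix) error {
//		f.started = true
//		return nil
//	}
//
//	func (f *fakeVNCServer) Stop() error { f.started = false; return nil }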

func (e *Engine) setupVNCPortRedirection() error {
	if e.firewall == nil || e.wgInterface == nil {
		return nil
	}

	localAddr := e.wgInterface.Address().IP
	if !localAddr.IsValid() {
		return errors.New("invalid local NetBird address")
	}

	if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
		return fmt.Errorf("add VNC port redirection: %w", err)
	}
	log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)

	return nil
}

func (e *Engine) cleanupVNCPortRedirection() error {
	if e.firewall == nil || e.wgInterface == nil {
		return nil
	}

	localAddr := e.wgInterface.Address().IP
	if !localAddr.IsValid() {
		return errors.New("invalid local NetBird address")
	}

	if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
		return fmt.Errorf("remove VNC port redirection: %w", err)
	}

	return nil
}

// updateVNC handles starting/stopping the VNC server based on the config flag.
// sshConf provides the JWT identity provider config (shared with SSH).
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
	if !e.config.ServerVNCAllowed {
		if e.vncSrv != nil {
			log.Info("VNC server disabled, stopping")
		}
		return e.stopVNCServer()
	}

	if e.config.BlockInbound {
		log.Info("VNC server disabled because inbound connections are blocked")
		return e.stopVNCServer()
	}

	if e.vncSrv != nil {
		// Update JWT config on existing server in case management sent new config.
		e.updateVNCServerJWT(sshConf)
		return nil
	}

	return e.startVNCServer(sshConf)
}

func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
	if e.wgInterface == nil {
		return errors.New("wg interface not initialized")
	}

	capturer, injector := newPlatformVNC()
	if capturer == nil || injector == nil {
		log.Debug("VNC server not supported on this platform")
		return nil
	}

	netbirdIP := e.wgInterface.Address().IP

	srv := vncserver.New(capturer, injector, "")
	if vncNeedsServiceMode() {
		log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
		srv.SetServiceMode(true)
	}

	// Configure VNC authentication.
	if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
		log.Info("VNC: authentication disabled by config")
		srv.SetDisableAuth(true)
	} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
		audiences := protoJWT.GetAudiences()
		if len(audiences) == 0 && protoJWT.GetAudience() != "" {
			audiences = []string{protoJWT.GetAudience()}
		}
		srv.SetJWTConfig(&vncserver.JWTConfig{
			Issuer:       protoJWT.GetIssuer(),
			Audiences:    audiences,
			KeysLocation: protoJWT.GetKeysLocation(),
			MaxTokenAge:  protoJWT.GetMaxTokenAge(),
		})
		log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
	}

	e.configureVNCRecording(srv, sshConf)

	if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
		srv.SetNetstackNet(netstackNet)
	}

	listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
	network := e.wgInterface.Address().Network
	if err := srv.Start(e.ctx, listenAddr, network); err != nil {
		return fmt.Errorf("start VNC server: %w", err)
	}

	e.vncSrv = srv

	if registrar, ok := e.firewall.(interface {
		RegisterNetstackService(protocol nftypes.Protocol, port uint16)
	}); ok {
		registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
		log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
	}

	if err := e.setupVNCPortRedirection(); err != nil {
		log.Warnf("setup VNC port redirection: %v", err)
	}

	log.Info("VNC server enabled")
	return nil
}

// configureVNCRecording enables session recording on the VNC server from the
// management-supplied settings. The env var NB_VNC_FORCE_RECORDING overrides
// the API for local development: when set, recording is always enabled and
// writes into that directory. Otherwise recordings go next to the state file
// under vnc-recordings/.
func (e *Engine) configureVNCRecording(srv *vncserver.Server, sshConf *mgmProto.SSHConfig) {
	recDir := os.Getenv(envVNCForceRecording)
	apiEnabled := sshConf.GetEnableRecording()

	if recDir == "" && !apiEnabled {
		log.Debugf("VNC recording disabled (env=%q, api=%v)", recDir, apiEnabled)
		return
	}

	if recDir == "" {
		base := e.defaultRecordingBase()
		if base == "" {
			log.Warn("VNC recording requested by management but no state directory is available")
			return
		}
		recDir = filepath.Join(base, "vnc-recordings")
	} else {
		recDir = filepath.Join(recDir, "vnc")
	}

	srv.SetRecordingDir(recDir)
	log.Infof("VNC recording enabled (dir=%s, source=%s)", recDir, recordingSource(apiEnabled))

	encKey := string(sshConf.GetRecordingEncryptionKey())
	if encKey == "" {
		encKey = os.Getenv("NB_VNC_RECORDING_ENCRYPTION_KEY")
	}
	if encKey != "" {
		srv.SetRecordingEncryptionKey(encKey)
		log.Info("VNC recording encryption enabled")
	}
}

func (e *Engine) defaultRecordingBase() string {
	if e.stateManager == nil {
		return ""
	}
	p := e.stateManager.FilePath()
	if p == "" {
		return ""
	}
	return filepath.Dir(p)
}

func recordingSource(api bool) string {
	if api {
		return "management"
	}
	return "env"
}
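
// Aside: the precedence in configureVNCRecording is env var first, then the
// management API. A hedged illustration of the resulting directories (paths
// are examples only):
//
//	NB_VNC_FORCE_RECORDING=/tmp/rec   -> recordings in /tmp/rec/vnc
//	API EnableRecording=true, no env  -> next to the state file,
//	                                     e.g. <state-dir>/vnc-recordings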
|
||||
|
||||
// updateVNCServerJWT configures the JWT validation for the VNC server using
|
||||
// the same JWT config as SSH (same identity provider).
|
||||
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
|
||||
if e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||
vncSrv.SetDisableAuth(true)
|
||||
return
|
||||
}
|
||||
|
||||
protoJWT := sshConf.GetJwtConfig()
|
||||
if protoJWT == nil {
|
||||
return
|
||||
}
|
||||
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
|
||||
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
})
|
||||
}
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if vncAuth == nil || e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
UserIDClaim: vncAuth.GetUserIDClaim(),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running.
|
||||
func (e *Engine) GetVNCServerStatus() bool {
|
||||
return e.vncSrv != nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
23
client/internal/engine_vnc_darwin.go
Normal file
23
client/internal/engine_vnc_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}
|
||||
}
|
||||
return capturer, injector
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
13
client/internal/engine_vnc_stub.go
Normal file
13
client/internal/engine_vnc_stub.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
13
client/internal/engine_vnc_windows.go
Normal file
13
client/internal/engine_vnc_windows.go
Normal file
@@ -0,0 +1,13 @@
//go:build windows

package internal

import vncserver "github.com/netbirdio/netbird/client/vnc/server"

func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
	return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
}

func vncNeedsServiceMode() bool {
	return vncserver.GetCurrentSessionID() == 0
}
23  client/internal/engine_vnc_x11.go  Normal file
@@ -0,0 +1,23 @@
//go:build (linux && !android) || freebsd

package internal

import (
	log "github.com/sirupsen/logrus"

	vncserver "github.com/netbirdio/netbird/client/vnc/server"
)

func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
	capturer := vncserver.NewX11Poller("")
	injector, err := vncserver.NewX11InputInjector("")
	if err != nil {
		log.Debugf("VNC: X11 input injector: %v", err)
		return capturer, &vncserver.StubInputInjector{}
	}
	return capturer, injector
}

func vncNeedsServiceMode() bool {
	return false
}
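Build tags ensure exactly one newPlatformVNC definition per target, so the engine never branches on GOOS itself. A hypothetical caller sketch (not the actual engine code) showing how the stub build's nil return naturally disables the feature:

// maybeStartVNC is a sketch: a nil capturer from the stub build means
// "VNC unsupported on this platform", and a StubInputInjector from the
// darwin/x11 fallback paths yields a view-only session.
func (e *Engine) maybeStartVNC() bool {
	capturer, injector := newPlatformVNC()
	if capturer == nil {
		return false
	}
	_ = injector
	return true
}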
@@ -3,6 +3,7 @@ package activity
import (
	"net"
	"net/netip"
	"runtime"
	"testing"
	"time"

@@ -17,6 +18,10 @@ import (
	peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)

func isBindListenerPlatform() bool {
	return runtime.GOOS == "windows" || runtime.GOOS == "js"
}

// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
	endpoints map[netip.Addr]net.Conn

@@ -176,6 +181,10 @@ func TestBindListener_Close(t *testing.T) {
}

func TestManager_BindMode(t *testing.T) {
	if !isBindListenerPlatform() {
		t.Skip("BindListener only used on Windows/JS platforms")
	}

	mockEndpointMgr := newMockEndpointManager()
	mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}

@@ -217,6 +226,10 @@ func TestManager_BindMode(t *testing.T) {
}

func TestManager_BindMode_MultiplePeers(t *testing.T) {
	if !isBindListenerPlatform() {
		t.Skip("BindListener only used on Windows/JS platforms")
	}

	mockEndpointMgr := newMockEndpointManager()
	mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -4,12 +4,14 @@ import (
	"errors"
	"net"
	"net/netip"
	"runtime"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

	"github.com/netbirdio/netbird/client/iface/netstack"
	"github.com/netbirdio/netbird/client/iface/wgaddr"
	"github.com/netbirdio/netbird/client/internal/lazyconn"
	peerid "github.com/netbirdio/netbird/client/internal/peer/id"

@@ -73,6 +75,16 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
	return NewUDPListener(m.wgIface, peerCfg)
}

	// BindListener is used on Windows, JS, and netstack platforms:
	//   - JS: Cannot listen to UDP sockets
	//   - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
	//     gateway points to, preventing them from reaching the loopback interface.
	//   - Netstack: Allows multiple instances on the same host without port conflicts.
	// BindListener bypasses these issues by passing data directly through the bind.
	if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
		return NewUDPListener(m.wgIface, peerCfg)
	}

	provider, ok := m.wgIface.(bindProvider)
	if !ok {
		return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
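The platform check above routes Windows, JS, and netstack builds to the bind-backed listener for the reasons listed in the comment. A minimal sketch of that decision in isolation (netstackEnabled stands in for netstack.IsEnabled(); the string result is purely illustrative):

// listenerKind picks the listener strategy for a platform. Sketch only:
// the real createListener constructs UDP or bind listeners directly.
func listenerKind(goos string, netstackEnabled bool) string {
	if goos != "windows" && goos != "js" && !netstackEnabled {
		return "udp" // a regular loopback UDP socket works here
	}
	// JS cannot open UDP sockets; Windows IP_UNICAST_IF bypasses loopback;
	// netstack needs port-conflict-free multi-instance operation.
	return "bind"
}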
@@ -6,6 +6,7 @@ import (
	"time"

	log "github.com/sirupsen/logrus"
	"golang.org/x/exp/maps"

	"github.com/netbirdio/netbird/client/internal/lazyconn"
	"github.com/netbirdio/netbird/client/internal/lazyconn/activity"

@@ -90,8 +91,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
	m.routesMu.Lock()
	defer m.routesMu.Unlock()

	clear(m.peerToHAGroups)
	clear(m.haGroupToPeers)
	maps.Clear(m.peerToHAGroups)
	maps.Clear(m.haGroupToPeers)

	for haUniqueID, routes := range haMap {
		var peers []string
@@ -3,6 +3,8 @@ package store
import (
	"sync"

	"golang.org/x/exp/maps"

	"github.com/google/uuid"

	"github.com/netbirdio/netbird/client/internal/netflow/types"

@@ -28,7 +30,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() {
	m.mux.Lock()
	defer m.mux.Unlock()
	clear(m.events)
	maps.Clear(m.events)
}

func (m *Memory) GetEvents() []*types.Event {
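The clear/maps.Clear swap in the two hunks above is behavior-preserving: clear is the Go 1.21 builtin that empties a map in place, and maps.Clear from golang.org/x/exp/maps is its pre-1.21 equivalent. A minimal illustration:

package main

import (
	"fmt"

	"golang.org/x/exp/maps"
)

func main() {
	a := map[string]int{"x": 1}
	b := map[string]int{"y": 2}

	clear(a)      // Go 1.21+ builtin: removes all entries, keeps the map allocated
	maps.Clear(b) // x/exp equivalent for older toolchains

	fmt.Println(len(a), len(b)) // 0 0
}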
@@ -185,20 +185,17 @@ func (conn *Conn) Open(engineCtx context.Context) error {

	conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)

	forceRelay := IsForceRelayed()
	if !forceRelay {
		relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
		workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
		if err != nil {
			return err
		}
		conn.workerICE = workerICE
	relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
	workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
	if err != nil {
		return err
	}
	conn.workerICE = workerICE

	conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)

	conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
	if !forceRelay {
	if !isForceRelayed() {
		conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
	}

@@ -254,9 +251,7 @@ func (conn *Conn) Close(signalToRemote bool) {
		conn.wgWatcherCancel()
	}
	conn.workerRelay.CloseConn()
	if conn.workerICE != nil {
		conn.workerICE.Close()
	}
	conn.workerICE.Close()

	if conn.wgProxyRelay != nil {
		err := conn.wgProxyRelay.CloseConn()

@@ -299,9 +294,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate handles an ICE connection candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
	conn.dumpState.RemoteCandidate()
	if conn.workerICE != nil {
		conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
	}
	conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}

// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer is established
@@ -719,35 +712,33 @@ func (conn *Conn) evalStatus() ConnStatus {
	return StatusConnecting
}

// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
//
// The result is a tri-state:
//   - ConnStatusConnected: all available transports are up
//   - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
//   - ConnStatusDisconnected: no working transport
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
	// would be better to protect this with a mutex, but it could cause deadlock with Close function

	defer func() {
		if status == guard.ConnStatusDisconnected {
		if !connected {
			conn.logTraceConnState()
		}
	}()

	iceWorkerCreated := conn.workerICE != nil

	var iceInProgress bool
	if iceWorkerCreated {
		iceInProgress = conn.workerICE.InProgress()
	// For JS platform: only relay connection is supported
	if runtime.GOOS == "js" {
		return conn.statusRelay.Get() == worker.StatusConnected
	}

	return evalConnStatus(connStatusInputs{
		forceRelay:          IsForceRelayed(),
		peerUsesRelay:       conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
		relayConnected:      conn.statusRelay.Get() == worker.StatusConnected,
		remoteSupportsICE:   conn.handshaker.RemoteICESupported(),
		iceWorkerCreated:    iceWorkerCreated,
		iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
		iceInProgress:       iceInProgress,
	})
	// For non-JS platforms: check ICE connection status
	if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
		return false
	}

	// If relay is supported with peer, it must also be connected
	if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
		if conn.statusRelay.Get() == worker.StatusDisconnected {
			return false
		}
	}

	return true
}

func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {

@@ -935,43 +926,3 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
	return remoteRosenpassPubKey != nil
}

func evalConnStatus(in connStatusInputs) guard.ConnStatus {
	// "Relay up and needed" — the peer uses relay and the transport is connected.
	relayUsedAndUp := in.peerUsesRelay && in.relayConnected

	// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
	if in.forceRelay {
		return boolToConnStatus(relayUsedAndUp)
	}

	// Remote peer doesn't support ICE, or we haven't created the worker yet:
	// relay is the only possible transport.
	if !in.remoteSupportsICE || !in.iceWorkerCreated {
		return boolToConnStatus(relayUsedAndUp)
	}

	// ICE counts as "up" when the status is anything other than Disconnected, OR
	// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
	iceUp := in.iceStatusConnecting || in.iceInProgress

	// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
	relayOK := !in.peerUsesRelay || in.relayConnected

	switch {
	case iceUp && relayOK:
		return guard.ConnStatusConnected
	case relayUsedAndUp:
		// Relay is up but ICE is down — partially connected.
		return guard.ConnStatusPartiallyConnected
	default:
		return guard.ConnStatusDisconnected
	}
}

func boolToConnStatus(connected bool) guard.ConnStatus {
	if connected {
		return guard.ConnStatusConnected
	}
	return guard.ConnStatusDisconnected
}
@@ -13,20 +13,6 @@ const (
	StatusConnected
)

// connStatusInputs is the primitive-valued snapshot of the state that drives the
// tri-state connection classification. Extracted so the decision logic can be unit-tested
// without constructing full Worker/Handshaker objects.
type connStatusInputs struct {
	forceRelay          bool // NB_FORCE_RELAY or JS/WASM
	peerUsesRelay       bool // remote peer advertises relay support AND local has relay
	relayConnected      bool // statusRelay reports Connected (independent of whether peer uses relay)
	remoteSupportsICE   bool // remote peer sent ICE credentials
	iceWorkerCreated    bool // local WorkerICE exists (false in force-relay mode)
	iceStatusConnecting bool // statusICE is anything other than Disconnected
	iceInProgress       bool // a negotiation is currently in flight
}

// ConnStatus describes the status of a peer's connection
type ConnStatus int32
@@ -1,201 +0,0 @@
package peer

import (
	"testing"

	"github.com/netbirdio/netbird/client/internal/peer/guard"
)

func TestEvalConnStatus_ForceRelay(t *testing.T) {
	tests := []struct {
		name string
		in   connStatusInputs
		want guard.ConnStatus
	}{
		{
			name: "force relay, peer uses relay, relay up",
			in: connStatusInputs{
				forceRelay:     true,
				peerUsesRelay:  true,
				relayConnected: true,
			},
			want: guard.ConnStatusConnected,
		},
		{
			name: "force relay, peer uses relay, relay down",
			in: connStatusInputs{
				forceRelay:     true,
				peerUsesRelay:  true,
				relayConnected: false,
			},
			want: guard.ConnStatusDisconnected,
		},
		{
			name: "force relay, peer does NOT use relay - disconnected forever",
			in: connStatusInputs{
				forceRelay:     true,
				peerUsesRelay:  false,
				relayConnected: true,
			},
			want: guard.ConnStatusDisconnected,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if got := evalConnStatus(tc.in); got != tc.want {
				t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
			}
		})
	}
}

func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
	tests := []struct {
		name string
		in   connStatusInputs
		want guard.ConnStatus
	}{
		{
			name: "remote does not support ICE, peer uses relay, relay up",
			in: connStatusInputs{
				peerUsesRelay:     true,
				relayConnected:    true,
				remoteSupportsICE: false,
				iceWorkerCreated:  true,
			},
			want: guard.ConnStatusConnected,
		},
		{
			name: "remote does not support ICE, peer uses relay, relay down",
			in: connStatusInputs{
				peerUsesRelay:     true,
				relayConnected:    false,
				remoteSupportsICE: false,
				iceWorkerCreated:  true,
			},
			want: guard.ConnStatusDisconnected,
		},
		{
			name: "ICE worker not yet created, relay up",
			in: connStatusInputs{
				peerUsesRelay:     true,
				relayConnected:    true,
				remoteSupportsICE: true,
				iceWorkerCreated:  false,
			},
			want: guard.ConnStatusConnected,
		},
		{
			name: "remote does not support ICE, peer does not use relay",
			in: connStatusInputs{
				peerUsesRelay:     false,
				relayConnected:    false,
				remoteSupportsICE: false,
				iceWorkerCreated:  true,
			},
			want: guard.ConnStatusDisconnected,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if got := evalConnStatus(tc.in); got != tc.want {
				t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
			}
		})
	}
}

func TestEvalConnStatus_FullyAvailable(t *testing.T) {
	base := connStatusInputs{
		remoteSupportsICE: true,
		iceWorkerCreated:  true,
	}

	tests := []struct {
		name    string
		mutator func(*connStatusInputs)
		want    guard.ConnStatus
	}{
		{
			name: "ICE connected, relay connected, peer uses relay",
			mutator: func(in *connStatusInputs) {
				in.peerUsesRelay = true
				in.relayConnected = true
				in.iceStatusConnecting = true
			},
			want: guard.ConnStatusConnected,
		},
		{
			name: "ICE connected, peer does NOT use relay",
			mutator: func(in *connStatusInputs) {
				in.peerUsesRelay = false
				in.relayConnected = false
				in.iceStatusConnecting = true
			},
			want: guard.ConnStatusConnected,
		},
		{
			name: "ICE InProgress only, peer does NOT use relay",
			mutator: func(in *connStatusInputs) {
				in.peerUsesRelay = false
				in.iceStatusConnecting = false
				in.iceInProgress = true
			},
			want: guard.ConnStatusConnected,
		},
		{
			name: "ICE down, relay up, peer uses relay -> partial",
			mutator: func(in *connStatusInputs) {
				in.peerUsesRelay = true
				in.relayConnected = true
				in.iceStatusConnecting = false
				in.iceInProgress = false
			},
			want: guard.ConnStatusPartiallyConnected,
		},
		{
			name: "ICE down, peer does NOT use relay -> disconnected",
			mutator: func(in *connStatusInputs) {
				in.peerUsesRelay = false
				in.relayConnected = false
				in.iceStatusConnecting = false
				in.iceInProgress = false
			},
			want: guard.ConnStatusDisconnected,
		},
		{
			name: "ICE up, peer uses relay but relay down -> disconnected (relay required, ICE ignored)",
			mutator: func(in *connStatusInputs) {
				in.peerUsesRelay = true
				in.relayConnected = false
				in.iceStatusConnecting = true
			},
			// relayOK = false (peer uses relay but it's down), iceUp = true:
			// the first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
			// so it falls into default: Disconnected.
			want: guard.ConnStatusDisconnected,
		},
		{
			name: "ICE down, relay up but peer does not use relay -> disconnected",
			mutator: func(in *connStatusInputs) {
				in.peerUsesRelay = false
				in.relayConnected = true // not actually used since peer doesn't rely on it
				in.iceStatusConnecting = false
				in.iceInProgress = false
			},
			want: guard.ConnStatusDisconnected,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			in := base
			tc.mutator(&in)
			if got := evalConnStatus(in); got != tc.want {
				t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
			}
		})
	}
}
@@ -7,38 +7,12 @@ import (
)

const (
	EnvKeyNBForceRelay       = "NB_FORCE_RELAY"
	EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
	EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)

func IsForceRelayed() bool {
func isForceRelayed() bool {
	if runtime.GOOS == "js" {
		return true
	}
	return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}

// OverrideRelayURLs returns the relay server URL list set in
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
// the override is active. When the env var is unset, the boolean is false
// and the caller should keep the list received from the management server.
// Intended for lab/debug scenarios where a peer must pin to a specific home
// relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
	raw := os.Getenv(EnvKeyNBHomeRelayServers)
	if raw == "" {
		return nil, false
	}
	parts := strings.Split(raw, ",")
	urls := make([]string, 0, len(parts))
	for _, p := range parts {
		p = strings.TrimSpace(p)
		if p != "" {
			urls = append(urls, p)
		}
	}
	if len(urls) == 0 {
		return nil, false
	}
	return urls, true
}
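OverrideRelayURLs only reports an active override when the variable yields at least one non-empty URL, which suggests a call site like the following sketch; pickRelayURLs and mgmtURLs are illustrative names, not NetBird code.

// pickRelayURLs prefers the NB_HOME_RELAY_SERVERS override when present
// and otherwise keeps the list received from the management server.
func pickRelayURLs(mgmtURLs []string) []string {
	if urls, ok := OverrideRelayURLs(); ok {
		return urls // lab/debug pin: ignore management's relay list
	}
	return mgmtURLs
}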
@@ -8,19 +8,7 @@ import (
	log "github.com/sirupsen/logrus"
)

// ConnStatus represents the connection state as seen by the guard.
type ConnStatus int

const (
	// ConnStatusDisconnected means neither ICE nor Relay is connected.
	ConnStatusDisconnected ConnStatus = iota
	// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
	ConnStatusPartiallyConnected
	// ConnStatusConnected means all required connections are established.
	ConnStatusConnected
)

type connStatusFunc func() ConnStatus
type isConnectedFunc func() bool

// Guard is responsible for the reconnection logic.
// It triggers sending an offer to the peer when the connection has issues.

@@ -32,14 +20,14 @@ type connStatusFunc func() ConnStatus
// - ICE candidate changes
type Guard struct {
	log                     *log.Entry
	isConnectedOnAllWay     connStatusFunc
	isConnectedOnAllWay     isConnectedFunc
	timeout                 time.Duration
	srWatcher               *SRWatcher
	relayedConnDisconnected chan struct{}
	iCEConnDisconnected     chan struct{}
}

func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
	return &Guard{
		log:                 log,
		isConnectedOnAllWay: isConnectedFn,
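Both the old connStatusFunc and the new isConnectedFunc inject the connectivity probe as a bare function type, keeping Guard decoupled from Conn. A toy sketch of that injection style (miniGuard and probe are illustrative names, not NetBird code):

type probe func() bool

type miniGuard struct{ connected probe }

// tick resends an offer whenever the injected probe reports trouble.
func (g *miniGuard) tick(resend func()) {
	if !g.connected() {
		resend()
	}
}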
@@ -69,17 +57,8 @@ func (g *Guard) SetICEConnDisconnected() {
	}
}

// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
//
// Behavior depends on the connection state reported by isConnectedOnAllWay:
//   - Connected: no action, the peer is fully reachable.
//   - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
//     up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
//   - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
//     to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
//
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
// reconnectLoopWithRetry periodically checks the connection status.
// It tries to send an offer while P2P is not established, or while the relay is not connected if relay is supported.
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
	srReconnectedChan := g.srWatcher.NewListener()
	defer g.srWatcher.RemoveListener(srReconnectedChan)

@@ -89,47 +68,36 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {

	tickerChannel := ticker.C

	iceState := &iceRetryState{log: g.log}
	defer iceState.reset()

	for {
		select {
		case <-tickerChannel:
			switch g.isConnectedOnAllWay() {
			case ConnStatusConnected:
				// all good, nothing to do
			case ConnStatusDisconnected:
				callback()
			case ConnStatusPartiallyConnected:
				if iceState.shouldRetry() {
					callback()
				} else {
					iceState.enterHourlyMode()
					ticker.Stop()
					tickerChannel = iceState.hourlyC()
				}
		case t := <-tickerChannel:
			if t.IsZero() {
				g.log.Infof("retry timed out, stop periodic offer sending")
				// after the backoff timeout, ticker.C will be closed; we need a dummy channel to avoid a busy loop
				tickerChannel = make(<-chan time.Time)
				continue
			}

			if !g.isConnectedOnAllWay() {
				callback()
			}
		case <-g.relayedConnDisconnected:
			g.log.Debugf("Relay connection changed, reset reconnection ticker")
			ticker.Stop()
			ticker = g.newReconnectTicker(ctx)
			ticker = g.prepareExponentTicker(ctx)
			tickerChannel = ticker.C
			iceState.reset()

		case <-g.iCEConnDisconnected:
			g.log.Debugf("ICE connection changed, reset reconnection ticker")
			ticker.Stop()
			ticker = g.newReconnectTicker(ctx)
			ticker = g.prepareExponentTicker(ctx)
			tickerChannel = ticker.C
			iceState.reset()

		case <-srReconnectedChan:
			g.log.Debugf("has network changes, reset reconnection ticker")
			ticker.Stop()
			ticker = g.newReconnectTicker(ctx)
			ticker = g.prepareExponentTicker(ctx)
			tickerChannel = ticker.C
			iceState.reset()

		case <-ctx.Done():
			g.log.Debugf("context is done, stop reconnect loop")

@@ -152,7 +120,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
	return backoff.NewTicker(bo)
}

func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
	bo := backoff.WithContext(&backoff.ExponentialBackOff{
		InitialInterval:     800 * time.Millisecond,
		RandomizationFactor: 0.1,
@@ -1,61 +0,0 @@
package guard

import (
	"time"

	log "github.com/sirupsen/logrus"
)

const (
	// maxICERetries is the maximum number of ICE offer attempts when relay is connected
	maxICERetries = 3
	// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
	iceRetryInterval = 1 * time.Hour
)

// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
// After maxICERetries attempts it switches to a periodic hourly retry.
type iceRetryState struct {
	log     *log.Entry
	retries int
	hourly  *time.Ticker
}

func (s *iceRetryState) reset() {
	s.retries = 0
	if s.hourly != nil {
		s.hourly.Stop()
		s.hourly = nil
	}
}

// shouldRetry reports whether the caller should send another ICE offer on this tick.
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
// to the hourly ticker via enterHourlyMode + hourlyC.
func (s *iceRetryState) shouldRetry() bool {
	if s.hourly != nil {
		s.log.Debugf("hourly ICE retry attempt")
		return true
	}

	s.retries++
	if s.retries <= maxICERetries {
		s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
		return true
	}

	return false
}

// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
func (s *iceRetryState) enterHourlyMode() {
	s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
	s.hourly = time.NewTicker(iceRetryInterval)
}

func (s *iceRetryState) hourlyC() <-chan time.Time {
	if s.hourly == nil {
		return nil
	}
	return s.hourly.C
}
@@ -1,103 +0,0 @@
package guard

import (
	"testing"

	log "github.com/sirupsen/logrus"
)

func newTestRetryState() *iceRetryState {
	return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
}

func TestICERetryState_AllowsInitialBudget(t *testing.T) {
	s := newTestRetryState()

	for i := 1; i <= maxICERetries; i++ {
		if !s.shouldRetry() {
			t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
		}
	}
}

func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
	s := newTestRetryState()

	for i := 0; i < maxICERetries; i++ {
		_ = s.shouldRetry()
	}

	if s.shouldRetry() {
		t.Fatalf("shouldRetry returned true after budget exhausted, want false")
	}
}

func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
	s := newTestRetryState()

	if s.hourlyC() != nil {
		t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
	}
}

func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
	s := newTestRetryState()
	for i := 0; i < maxICERetries+1; i++ {
		_ = s.shouldRetry()
	}

	s.enterHourlyMode()
	defer s.reset()

	if s.hourlyC() == nil {
		t.Fatalf("hourlyC returned nil after enterHourlyMode")
	}
}

func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
	s := newTestRetryState()
	s.enterHourlyMode()
	defer s.reset()

	if !s.shouldRetry() {
		t.Fatalf("shouldRetry returned false in hourly mode, want true")
	}

	// Subsequent calls also return true — we keep retrying on each hourly tick.
	if !s.shouldRetry() {
		t.Fatalf("second shouldRetry returned false in hourly mode, want true")
	}
}

func TestICERetryState_ResetRestoresBudget(t *testing.T) {
	s := newTestRetryState()
	for i := 0; i < maxICERetries+1; i++ {
		_ = s.shouldRetry()
	}
	s.enterHourlyMode()

	s.reset()

	if s.hourlyC() != nil {
		t.Fatalf("hourlyC returned non-nil channel after reset")
	}
	if s.retries != 0 {
		t.Fatalf("retries = %d after reset, want 0", s.retries)
	}

	for i := 1; i <= maxICERetries; i++ {
		if !s.shouldRetry() {
			t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
		}
	}
}

func TestICERetryState_ResetIsIdempotent(t *testing.T) {
	s := newTestRetryState()
	s.reset()
	s.reset() // second call must not panic or re-stop a nil ticker

	if s.hourlyC() != nil {
		t.Fatalf("hourlyC non-nil after double reset")
	}
}
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
	return srw
}

func (w *SRWatcher) Start(disableICEMonitor bool) {
func (w *SRWatcher) Start() {
	w.mu.Lock()
	defer w.mu.Unlock()

@@ -50,10 +50,8 @@ func (w *SRWatcher) Start(disableICEMonitor bool) {
	ctx, cancel := context.WithCancel(context.Background())
	w.cancelIceMonitor = cancel

	if !disableICEMonitor {
		iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
		go iceMonitor.Start(ctx, w.onICEChanged)
	}
	iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
	go iceMonitor.Start(ctx, w.onICEChanged)

	w.signalClient.SetOnReconnectedListener(w.onReconnected)
	w.relayManager.SetOnReconnectedListener(w.onReconnected)
@@ -3,9 +3,7 @@ package peer
import (
	"context"
	"errors"
	"net/netip"
	"sync"
	"sync/atomic"

	log "github.com/sirupsen/logrus"

@@ -41,18 +39,10 @@ type OfferAnswer struct {

	// relay server address
	RelaySrvAddress string
	// RelaySrvIP is the IP the remote peer is connected to on its
	// relay server. Used as a dial target if DNS for RelaySrvAddress
	// fails. Zero value if the peer did not advertise an IP.
	RelaySrvIP netip.Addr
	// SessionID is the unique identifier of the session, used to discard old messages
	SessionID *ICESessionID
}

func (o *OfferAnswer) hasICECredentials() bool {
	return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
}

type Handshaker struct {
	mu  sync.Mutex
	log *log.Entry

@@ -69,10 +59,6 @@ type Handshaker struct {
	relayListener *AsyncOfferListener
	iceListener   func(remoteOfferAnswer *OfferAnswer)

	// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
	// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
	remoteICESupported atomic.Bool

	// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
	remoteOffersCh chan OfferAnswer
	// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection

@@ -80,7 +66,7 @@ type Handshaker struct {
}

func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
	h := &Handshaker{
	return &Handshaker{
		log:      log,
		config:   config,
		signaler: signaler,

@@ -90,13 +76,6 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
		remoteOffersCh: make(chan OfferAnswer),
		remoteAnswerCh: make(chan OfferAnswer),
	}
	// assume remote supports ICE until we learn otherwise from received offers
	h.remoteICESupported.Store(ice != nil)
	return h
}

func (h *Handshaker) RemoteICESupported() bool {
	return h.remoteICESupported.Load()
}

func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {

@@ -111,20 +90,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
	for {
		select {
		case remoteOfferAnswer := <-h.remoteOffersCh:
			h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
			h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())

			// Record signaling received for reconnection attempts
			if h.metricsStages != nil {
				h.metricsStages.RecordSignalingReceived()
			}

			h.updateRemoteICEState(&remoteOfferAnswer)

			if h.relayListener != nil {
				h.relayListener.Notify(&remoteOfferAnswer)
			}

			if h.iceListener != nil && h.RemoteICESupported() {
			if h.iceListener != nil {
				h.iceListener(&remoteOfferAnswer)
			}

@@ -133,20 +110,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
			continue
		}
		case remoteOfferAnswer := <-h.remoteAnswerCh:
			h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
			h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())

			// Record signaling received for reconnection attempts
			if h.metricsStages != nil {
				h.metricsStages.RecordSignalingReceived()
			}

			h.updateRemoteICEState(&remoteOfferAnswer)

			if h.relayListener != nil {
				h.relayListener.Notify(&remoteOfferAnswer)
			}

			if h.iceListener != nil && h.RemoteICESupported() {
			if h.iceListener != nil {
				h.iceListener(&remoteOfferAnswer)
			}
		case <-ctx.Done():

@@ -208,39 +183,20 @@ func (h *Handshaker) sendAnswer() error {
}

func (h *Handshaker) buildOfferAnswer() OfferAnswer {
	uFrag, pwd := h.ice.GetLocalUserCredentials()
	sid := h.ice.SessionID()
	answer := OfferAnswer{
		IceCredentials:  IceCredentials{uFrag, pwd},
		WgListenPort:    h.config.LocalWgPort,
		Version:         version.NetbirdVersion(),
		RosenpassPubKey: h.config.RosenpassConfig.PubKey,
		RosenpassAddr:   h.config.RosenpassConfig.Addr,
		SessionID:       &sid,
	}

	if h.ice != nil && h.RemoteICESupported() {
		uFrag, pwd := h.ice.GetLocalUserCredentials()
		sid := h.ice.SessionID()
		answer.IceCredentials = IceCredentials{uFrag, pwd}
		answer.SessionID = &sid
	}

	if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil {
	if addr, err := h.relay.RelayInstanceAddress(); err == nil {
		answer.RelaySrvAddress = addr
		answer.RelaySrvIP = ip
	}

	return answer
}

func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
	hasICE := offer.hasICECredentials()
	prev := h.remoteICESupported.Swap(hasICE)
	if prev != hasICE {
		if hasICE {
			h.log.Infof("remote peer started sending ICE credentials")
		} else {
			h.log.Infof("remote peer stopped sending ICE credentials")
			if h.ice != nil {
				h.ice.Close()
			}
		}
	}
}
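updateRemoteICEState above detects edges (the peer starting or stopping ICE credentials) with an atomic swap instead of a mutex. A self-contained sketch of that idiom:

package main

import (
	"fmt"
	"sync/atomic"
)

// edgeDetector reports only transitions, not steady-state observations,
// mirroring the remoteICESupported.Swap comparison above.
type edgeDetector struct{ state atomic.Bool }

func (e *edgeDetector) observe(now bool) {
	if prev := e.state.Swap(now); prev != now {
		fmt.Printf("transition: %t -> %t\n", prev, now)
	}
}

func main() {
	var e edgeDetector
	e.observe(true)  // transition: false -> true
	e.observe(true)  // steady state: no output
	e.observe(false) // transition: true -> false
}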
@@ -8,7 +8,6 @@ import (
type mocListener struct {
	lastState int
	wg        sync.WaitGroup
	peersWg   sync.WaitGroup
	peers     int
}

@@ -34,7 +33,6 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
}
func (l *mocListener) OnPeersListChanged(size int) {
	l.peers = size
	l.peersWg.Done()
}

func (l *mocListener) setWaiter() {

@@ -45,14 +43,6 @@ func (l *mocListener) wait() {
	l.wg.Wait()
}

func (l *mocListener) setPeersWaiter() {
	l.peersWg.Add(1)
}

func (l *mocListener) waitPeers() {
	l.peersWg.Wait()
}

func Test_notifier_serverState(t *testing.T) {

	type scenario struct {

@@ -82,13 +72,11 @@ func Test_notifier_serverState(t *testing.T) {
func Test_notifier_SetListener(t *testing.T) {
	listener := &mocListener{}
	listener.setWaiter()
	listener.setPeersWaiter()

	n := newNotifier()
	n.lastNotification = stateConnecting
	n.setListener(listener)
	listener.wait()
	listener.waitPeers()
	if listener.lastState != n.lastNotification {
		t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
	}

@@ -97,14 +85,9 @@ func Test_notifier_SetListener(t *testing.T) {
func Test_notifier_RemoveListener(t *testing.T) {
	listener := &mocListener{}
	listener.setWaiter()
	listener.setPeersWaiter()
	n := newNotifier()
	n.lastNotification = stateConnecting
	n.setListener(listener)
	// setListener replays cached state on a goroutine; wait for both the state
	// and peers callbacks to finish so we don't race on listener.peers.
	listener.wait()
	listener.waitPeers()
	n.removeListener()
	n.peerListChanged(1)
@@ -46,27 +46,23 @@ func (s *Signaler) Ready() bool {

// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
	var sessionIDBytes []byte
	if offerAnswer.SessionID != nil {
		var err error
		sessionIDBytes, err = offerAnswer.SessionID.Bytes()
		if err != nil {
			log.Warnf("failed to get session ID bytes: %v", err)
		}
	sessionIDBytes, err := offerAnswer.SessionID.Bytes()
	if err != nil {
		log.Warnf("failed to get session ID bytes: %v", err)
	}
	msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{
		Type:         bodyType,
		WgListenPort: offerAnswer.WgListenPort,
		Credential: &signal.Credential{
	msg, err := signal.MarshalCredential(
		s.wgPrivateKey,
		offerAnswer.WgListenPort,
		remoteKey,
		&signal.Credential{
			UFrag: offerAnswer.IceCredentials.UFrag,
			Pwd:   offerAnswer.IceCredentials.Pwd,
		},
		RosenpassPubKey: offerAnswer.RosenpassPubKey,
		RosenpassAddr:   offerAnswer.RosenpassAddr,
		RelaySrvAddress: offerAnswer.RelaySrvAddress,
		RelaySrvIP:      offerAnswer.RelaySrvIP,
		SessionID:       sessionIDBytes,
	})
		bodyType,
		offerAnswer.RosenpassPubKey,
		offerAnswer.RosenpassAddr,
		offerAnswer.RelaySrvAddress,
		sessionIDBytes)
	if err != nil {
		return err
	}
@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
// UpdatePeerState updates peer status
func (d *Status) UpdatePeerState(receivedState State) error {
	d.mux.Lock()
	defer d.mux.Unlock()

	peerState, ok := d.peers[receivedState.PubKey]
	if !ok {
		d.mux.Unlock()
		return errors.New("peer doesn't exist")
	}

@@ -343,29 +343,23 @@ func (d *Status) UpdatePeerState(receivedState State) error {

	d.peers[receivedState.PubKey] = peerState

	notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
	// when we close the connection we will not notify the router manager
	notifyRouter := receivedState.ConnStatus == StatusIdle
	routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
	numPeers := d.numOfPeers()

	d.mux.Unlock()

	if notifyList {
		d.notifier.peerListChanged(numPeers)
	if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
		d.notifyPeerListChanged()
	}
	if notifyRouter {
		d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)

	// when we close the connection we will not notify the router manager
	if receivedState.ConnStatus == StatusIdle {
		d.notifyPeerStateChangeListeners(receivedState.PubKey)
	}
	return nil
}

func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
	d.mux.Lock()
	defer d.mux.Unlock()

	peerState, ok := d.peers[peer]
	if !ok {
		d.mux.Unlock()
		return errors.New("peer doesn't exist")
	}

@@ -377,20 +371,17 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
		d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
	}

	numPeers := d.numOfPeers()
	d.mux.Unlock()

	// todo: reconsider whether this notification is needed
	d.notifier.peerListChanged(numPeers)
	d.notifyPeerListChanged()
	return nil
}

func (d *Status) RemovePeerStateRoute(peer string, route string) error {
	d.mux.Lock()
	defer d.mux.Unlock()

	peerState, ok := d.peers[peer]
	if !ok {
		d.mux.Unlock()
		return errors.New("peer doesn't exist")
	}

@@ -402,11 +393,8 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
		d.routeIDLookup.RemoveRemoteRouteID(pref)
	}

	numPeers := d.numOfPeers()
	d.mux.Unlock()

	// todo: reconsider whether this notification is needed
	d.notifier.peerListChanged(numPeers)
	d.notifyPeerListChanged()
	return nil
}
@@ -422,10 +410,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {

func (d *Status) UpdatePeerICEState(receivedState State) error {
	d.mux.Lock()
	defer d.mux.Unlock()

	peerState, ok := d.peers[receivedState.PubKey]
	if !ok {
		d.mux.Unlock()
		return errors.New("peer doesn't exist")
	}

@@ -443,28 +431,22 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {

	d.peers[receivedState.PubKey] = peerState

	notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
	notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
	routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
	numPeers := d.numOfPeers()

	d.mux.Unlock()

	if notifyList {
		d.notifier.peerListChanged(numPeers)
	if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
		d.notifyPeerListChanged()
	}
	if notifyRouter {
		d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)

	if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
		d.notifyPeerStateChangeListeners(receivedState.PubKey)
	}
	return nil
}

func (d *Status) UpdatePeerRelayedState(receivedState State) error {
	d.mux.Lock()
	defer d.mux.Unlock()

	peerState, ok := d.peers[receivedState.PubKey]
	if !ok {
		d.mux.Unlock()
		return errors.New("peer doesn't exist")
	}

@@ -479,28 +461,22 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {

	d.peers[receivedState.PubKey] = peerState

	notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
	notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
	routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
	numPeers := d.numOfPeers()

	d.mux.Unlock()

	if notifyList {
		d.notifier.peerListChanged(numPeers)
	if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
		d.notifyPeerListChanged()
	}
	if notifyRouter {
		d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)

	if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
		d.notifyPeerStateChangeListeners(receivedState.PubKey)
	}
	return nil
}

func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
	d.mux.Lock()
	defer d.mux.Unlock()

	peerState, ok := d.peers[receivedState.PubKey]
	if !ok {
		d.mux.Unlock()
		return errors.New("peer doesn't exist")
	}

@@ -514,28 +490,22 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error

	d.peers[receivedState.PubKey] = peerState

	notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
	notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
	routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
	numPeers := d.numOfPeers()

	d.mux.Unlock()

	if notifyList {
		d.notifier.peerListChanged(numPeers)
	if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
		d.notifyPeerListChanged()
	}
	if notifyRouter {
		d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)

	if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
		d.notifyPeerStateChangeListeners(receivedState.PubKey)
	}
	return nil
}

func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
	d.mux.Lock()
	defer d.mux.Unlock()

	peerState, ok := d.peers[receivedState.PubKey]
	if !ok {
		d.mux.Unlock()
		return errors.New("peer doesn't exist")
	}

@@ -552,18 +522,12 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {

	d.peers[receivedState.PubKey] = peerState

	notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
	notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
	routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
	numPeers := d.numOfPeers()

	d.mux.Unlock()

	if notifyList {
		d.notifier.peerListChanged(numPeers)
	if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
		d.notifyPeerListChanged()
	}
	if notifyRouter {
		d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)

	if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
		d.notifyPeerStateChangeListeners(receivedState.PubKey)
	}
	return nil
}
@@ -630,33 +594,17 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
// FinishPeerListModifications invokes the notification for batched peer list changes
func (d *Status) FinishPeerListModifications() {
	d.mux.Lock()
	defer d.mux.Unlock()

	if !d.peerListChangedForNotification {
		d.mux.Unlock()
		return
	}
	d.peerListChangedForNotification = false

	numPeers := d.numOfPeers()
	d.notifyPeerListChanged()

	// snapshot per-peer router state to deliver after the lock is released
	type routerDispatch struct {
		peerID   string
		snapshot map[string]RouterState
	}
	dispatches := make([]routerDispatch, 0, len(d.peers))
	for key := range d.peers {
		snapshot := d.snapshotRouterPeersLocked(key, true)
		if snapshot != nil {
			dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
		}
	}

	d.mux.Unlock()

	d.notifier.peerListChanged(numPeers)
	for _, rd := range dispatches {
		d.dispatchRouterPeers(rd.peerID, rd.snapshot)
		d.notifyPeerStateChangeListeners(key)
	}
}

@@ -707,12 +655,10 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
// UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
	d.mux.Lock()
	d.localPeer = localPeerState
	fqdn := d.localPeer.FQDN
	ip := d.localPeer.IP
	d.mux.Unlock()
	defer d.mux.Unlock()

	d.notifier.localAddressChanged(fqdn, ip)
	d.localPeer = localPeerState
	d.notifyAddressChanged()
}

// AddLocalPeerStateRoute adds a route to the local peer state

@@ -775,36 +721,30 @@ func (d *Status) CleanLocalPeerStateRoutes() {
// CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() {
	d.mux.Lock()
	d.localPeer = LocalPeerState{}
	fqdn := d.localPeer.FQDN
	ip := d.localPeer.IP
	d.mux.Unlock()
	defer d.mux.Unlock()

	d.notifier.localAddressChanged(fqdn, ip)
	d.localPeer = LocalPeerState{}
	d.notifyAddressChanged()
}

// MarkManagementDisconnected sets ManagementState to disconnected
func (d *Status) MarkManagementDisconnected(err error) {
	d.mux.Lock()
	defer d.mux.Unlock()
	defer d.onConnectionChanged()

	d.managementState = false
	d.managementError = err
	mgm := d.managementState
	sig := d.signalState
	d.mux.Unlock()

	d.notifier.updateServerStates(mgm, sig)
}

// MarkManagementConnected sets ManagementState to connected
func (d *Status) MarkManagementConnected() {
	d.mux.Lock()
	defer d.mux.Unlock()
	defer d.onConnectionChanged()

	d.managementState = true
	d.managementError = nil
	mgm := d.managementState
	sig := d.signalState
	d.mux.Unlock()

	d.notifier.updateServerStates(mgm, sig)
}

// UpdateSignalAddress update the address of the signal server

@@ -838,25 +778,21 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
// MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) {
	d.mux.Lock()
	defer d.mux.Unlock()
	defer d.onConnectionChanged()

	d.signalState = false
	d.signalError = err
	mgm := d.managementState
	sig := d.signalState
	d.mux.Unlock()

	d.notifier.updateServerStates(mgm, sig)
}

// MarkSignalConnected sets SignalState to connected
func (d *Status) MarkSignalConnected() {
	d.mux.Lock()
	defer d.mux.Unlock()
	defer d.onConnectionChanged()

	d.signalState = true
	d.signalError = nil
	mgm := d.managementState
	sig := d.signalState
	d.mux.Unlock()

	d.notifier.updateServerStates(mgm, sig)
}

func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {

@@ -983,7 +919,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {

	// if the server connection is not established then we will use the general address
	// in case of connection we will use the instance specific address
	instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
	instanceAddr, err := d.relayMgr.RelayInstanceAddress()
	if err != nil {
		// TODO add their status
		for _, r := range d.relayMgr.ServerURLs() {

@@ -1076,17 +1012,18 @@ func (d *Status) RemoveConnectionListener() {
	d.notifier.removeListener()
}

// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers.
// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID
// or when notify is false. The snapshot is consumed later by dispatchRouterPeers
// outside the lock so the channel send cannot stall any d.mux holder.
func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState {
	if !notify {
		return nil
	}
	if _, ok := d.changeNotify[peerID]; !ok {
		return nil
func (d *Status) onConnectionChanged() {
	d.notifier.updateServerStates(d.managementState, d.signalState)
}

// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
	subs, ok := d.changeNotify[peerID]
	if !ok {
		return
	}

	// collect the relevant data for router peers
	routerPeers := make(map[string]RouterState, len(d.changeNotify))
	for pid := range d.changeNotify {
		s, ok := d.peers[pid]

@@ -1094,35 +1031,13 @@ func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[strin
			log.Warnf("router peer not found in peers list: %s", pid)
			continue
		}

		routerPeers[pid] = RouterState{
			Status:  s.ConnStatus,
			Relayed: s.Relayed,
			Latency: s.Latency,
		}
	}
	return routerPeers
}

// dispatchRouterPeers delivers a previously snapshotted router-state map to
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
// fresh, short read of d.changeNotify under the lock to grab subscriber
// channels, then sends outside the lock so a slow consumer cannot block other
// d.mux holders. The send itself stays blocking (only short-circuited by the
// subscriber's context) so peer state transitions are not silently dropped.
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
	if routerPeers == nil {
		return
	}

	d.mux.Lock()
	subsMap, ok := d.changeNotify[peerID]
	subs := make([]*StatusChangeSubscription, 0, len(subsMap))
	if ok {
		for _, sub := range subsMap {
			subs = append(subs, sub)
		}
	}
	d.mux.Unlock()

	for _, sub := range subs {
		select {

@@ -1132,6 +1047,14 @@ func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]Route
	}
}

func (d *Status) notifyPeerListChanged() {
	d.notifier.peerListChanged(d.numOfPeers())
}

func (d *Status) notifyAddressChanged() {
	d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
}

func (d *Status) numOfPeers() int {
	return len(d.peers) + len(d.offlinePeers)
}
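The snapshotRouterPeersLocked/dispatchRouterPeers split removed above exists to keep blocking channel sends outside the d.mux critical section. A stripped-down sketch of the pattern (broadcaster is illustrative, not NetBird code):

package main

import "sync"

// broadcaster copies its subscriber list under the lock but delivers
// outside it, so a slow reader can never stall other lock holders.
type broadcaster struct {
	mu   sync.Mutex
	subs []chan int
}

func (b *broadcaster) publish(v int) {
	b.mu.Lock()
	snapshot := append([]chan int(nil), b.subs...) // copy under lock
	b.mu.Unlock()

	for _, ch := range snapshot {
		ch <- v // blocking send with no lock held
	}
}

func main() {
	b := &broadcaster{subs: []chan int{make(chan int, 1)}}
	b.publish(42)
}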
@@ -4,7 +4,6 @@ import (
	"context"
	"errors"
	"net"
	"net/netip"
	"sync"
	"sync/atomic"

@@ -54,19 +53,15 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
	w.relaySupportedOnRemotePeer.Store(true)

	// the relayManager will return an error if the connection to the relay server has been lost
	currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
	currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
	if err != nil {
		w.log.Errorf("failed to handle new offer: %s", err)
		return
	}

	srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
	var serverIP netip.Addr
	if srv == remoteOfferAnswer.RelaySrvAddress {
		serverIP = remoteOfferAnswer.RelaySrvIP
	}

	relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
	relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
	if err != nil {
		if errors.Is(err, relayClient.ErrConnAlreadyExists) {
			w.log.Debugf("handled offer by reusing existing relay connection")

@@ -95,7 +90,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
	})
}

func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) {
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
	return w.relayManager.RelayInstanceAddress()
}
@@ -5,7 +5,6 @@ import (
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"sync"
	"time"

@@ -178,12 +177,7 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
		return nil, nil, err
	}

	dst := net.IPv4zero
	if runtime.GOOS == "linux" {
		// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
		dst = net.IPv4(0, 0, 0, 1)
	}
	_, gateway, localIP, err = router.Route(dst)
	_, gateway, localIP, err = router.Route(net.IPv4zero)
	if err != nil {
		return nil, nil, err
	}
@@ -202,12 +196,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
		return nil, nil, err
	}

	dst := net.IPv6zero
	if runtime.GOOS == "linux" {
		// ::2
		dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
	}
	_, gateway, localIP, err = router.Route(dst)
	_, gateway, localIP, err = router.Route(net.IPv6zero)
	if err != nil {
		return nil, nil, err
	}

@@ -64,11 +64,13 @@ type ConfigInput struct {
	StateFilePath                 string
	PreSharedKey                  *string
	ServerSSHAllowed              *bool
	ServerVNCAllowed              *bool
	EnableSSHRoot                 *bool
	EnableSSHSFTP                 *bool
	EnableSSHLocalPortForwarding  *bool
	EnableSSHRemotePortForwarding *bool
	DisableSSHAuth                *bool
	DisableVNCAuth                *bool
	SSHJWTCacheTTL                *int
	NATExternalIPs                []string
	CustomDNSAddress              []byte
@@ -114,11 +116,13 @@ type Config struct {
	RosenpassEnabled              bool
	RosenpassPermissive           bool
	ServerSSHAllowed              *bool
	ServerVNCAllowed              *bool
	EnableSSHRoot                 *bool
	EnableSSHSFTP                 *bool
	EnableSSHLocalPortForwarding  *bool
	EnableSSHRemotePortForwarding *bool
	DisableSSHAuth                *bool
	DisableVNCAuth                *bool
	SSHJWTCacheTTL                *int

	DisableClientRoutes bool
@@ -415,6 +419,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
		updated = true
	}

	if input.ServerVNCAllowed != nil {
		if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
			if *input.ServerVNCAllowed {
				log.Infof("enabling VNC server")
			} else {
				log.Infof("disabling VNC server")
			}
			config.ServerVNCAllowed = input.ServerVNCAllowed
			updated = true
		}
	} else if config.ServerVNCAllowed == nil {
		config.ServerVNCAllowed = util.True()
		updated = true
	}

	if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
		if *input.EnableSSHRoot {
			log.Infof("enabling SSH root login")
@@ -465,6 +484,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
		updated = true
	}

	if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
		if *input.DisableVNCAuth {
			log.Infof("disabling VNC authentication")
		} else {
			log.Infof("enabling VNC authentication")
		}
		config.DisableVNCAuth = input.DisableVNCAuth
		updated = true
	}

	if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
		log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
		config.SSHJWTCacheTTL = input.SSHJWTCacheTTL

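The ServerVNCAllowed and DisableVNCAuth branches reuse the config file's tri-state *bool convention: nil means the field was never written, so a default can be installed exactly once without overriding a later explicit false. A reduced sketch of that convention; the type names and boolPtr helper here are illustrative, only the update logic mirrors apply:

package main

import "fmt"

type Config struct{ ServerVNCAllowed *bool }
type Input struct{ ServerVNCAllowed *bool }

func boolPtr(b bool) *bool { return &b }

// apply mirrors the tri-state update: an explicit input wins,
// and a missing value is defaulted only while the field is still unset.
func apply(cfg *Config, in Input) (updated bool) {
	if in.ServerVNCAllowed != nil {
		if cfg.ServerVNCAllowed == nil || *in.ServerVNCAllowed != *cfg.ServerVNCAllowed {
			cfg.ServerVNCAllowed = in.ServerVNCAllowed
			updated = true
		}
	} else if cfg.ServerVNCAllowed == nil {
		cfg.ServerVNCAllowed = boolPtr(true) // first run: default to enabled
		updated = true
	}
	return updated
}

func main() {
	var cfg Config
	fmt.Println(apply(&cfg, Input{}))                                 // true: default installed
	fmt.Println(apply(&cfg, Input{ServerVNCAllowed: boolPtr(false)})) // true: explicit override
	fmt.Println(apply(&cfg, Input{}))                                 // false: nothing to do
}
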
@@ -89,16 +89,8 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
		return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
	}

	reused := false
	if err := r.addScopedDefault(unspec, nexthop); err != nil {
		if !errors.Is(err, unix.EEXIST) {
			return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
		}
		// macOS installs its own RTF_IFSCOPE defaults for primary service
		// selection on multi-NIC setups, so a route on this ifindex can
		// already exist before we try. Binding to it via IP[V6]_BOUND_IF
		// still produces the scoped lookup we need.
		reused = true
		return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
	}

	af := unix.AF_INET
@@ -110,11 +102,7 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
	if nexthop.IP.IsValid() {
		via = nexthop.IP.String()
	}
	verb := "installed"
	if reused {
		verb = "reused existing"
	}
	log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
	log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
	return true, nil
}

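One side of this hunk treats unix.EEXIST from the route add as acceptable, since macOS may have pre-installed its own RTF_IFSCOPE default on the same interface index. The tolerate-and-reuse idiom, reduced to its core; addRoute stands in for the real routing-socket call and is not a NetBird function:

//go:build darwin

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sys/unix"
)

// installDefault adds a route but treats "already exists" as success,
// reporting whether the route was freshly installed or reused.
func installDefault(addRoute func() error) (reused bool, err error) {
	if err := addRoute(); err != nil {
		if !errors.Is(err, unix.EEXIST) {
			return false, fmt.Errorf("add scoped default: %w", err)
		}
		reused = true // the OS already had an equivalent route on this ifindex
	}
	return reused, nil
}

func main() {
	reused, err := installDefault(func() error { return unix.EEXIST })
	fmt.Println(reused, err) // true <nil>
}
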
@@ -342,22 +342,6 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
	if err != nil {
		return Nexthop{}, fmt.Errorf("new netroute: %w", err)
	}

	// go-netroute v0.4.0 rejects unspecified destinations on Linux with a hard
	// client-side check. Substitute the lowest non-loopback address so the
	// lookup falls through to the default route (::1 / 127.0.0.1 would match
	// loopback, ::/0.0.0.0 are unspec). BSD/Windows pass the query straight to
	// the kernel and need no substitution.
	if runtime.GOOS == "linux" && ip.IsUnspecified() {
		if ip.Is6() {
			// ::2
			ip = netip.AddrFrom16([16]byte{15: 2})
		} else {
			// 0.0.0.1
			ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
		}
	}

	intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
	if err != nil {
		log.Debugf("Failed to get route for %s: %v", ip, err)

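Both getDefaultGateway variants above and this GetNextHop hunk work around the same go-netroute v0.4.0 behavior. The substitution can be factored into one helper; the sketch below condenses the logic spelled out in the comment and is not a function that exists in the codebase:

package main

import (
	"fmt"
	"net/netip"
	"runtime"
)

// routeLookupAddr rewrites an unspecified destination into the lowest
// non-loopback address so go-netroute's Linux client-side check passes
// while the kernel still answers with the default route.
func routeLookupAddr(ip netip.Addr) netip.Addr {
	if runtime.GOOS != "linux" || !ip.IsUnspecified() {
		return ip // BSD/Windows hand the query to the kernel unchanged
	}
	if ip.Is6() {
		return netip.AddrFrom16([16]byte{15: 2}) // ::2
	}
	return netip.AddrFrom4([4]byte{0, 0, 0, 1}) // 0.0.0.1
}

func main() {
	fmt.Println(routeLookupAddr(netip.IPv4Unspecified()))
	fmt.Println(routeLookupAddr(netip.IPv6Unspecified()))
}
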
@@ -354,13 +354,9 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
	require.NoError(t, err, "Should be able to get IPv4 default route")
	t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)

	if testCase.prefix.Addr().Is6() && !testCase.expectError {
		ensureIPv6DefaultRoute(t)
	}

	initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
	if testCase.prefix.Addr().Is6() &&
		initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun") {
		(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
		t.Skip("Skipping test as no ipv6 default route is available")
	}
	if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {

@@ -1,30 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd

package systemops

import (
	"bytes"
	"os/exec"
	"testing"
)

// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
	t.Helper()

	out, err := exec.Command("route", "-6", "add", "default", "-iface", "lo0").CombinedOutput()
	if err != nil {
		// Existing default; nothing to install or clean up.
		if bytes.Contains(out, []byte("route already in table")) {
			return
		}
		t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
	}
	t.Cleanup(func() {
		if out, err := exec.Command("route", "-6", "delete", "default").CombinedOutput(); err != nil {
			t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
		}
	})
}
@@ -1,41 +0,0 @@
//go:build linux && !android

package systemops

import (
	"errors"
	"net"
	"syscall"
	"testing"

	"github.com/stretchr/testify/require"
	"github.com/vishvananda/netlink"
)

// ensureIPv6DefaultRoute installs a low-preference IPv6 default route via the
// loopback interface so route lookups for global IPv6 prefixes resolve in
// environments without v6 connectivity. Any pre-existing default route wins
// because of its lower metric.
func ensureIPv6DefaultRoute(t *testing.T) {
	t.Helper()

	lo, err := netlink.LinkByName("lo")
	require.NoError(t, err, "find loopback interface")

	route := &netlink.Route{
		Dst:       &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)},
		LinkIndex: lo.Attrs().Index,
		Priority:  1 << 20,
	}
	if err := netlink.RouteAdd(route); err != nil {
		if errors.Is(err, syscall.EEXIST) {
			return
		}
		t.Skipf("install IPv6 fallback default route: %v", err)
	}
	t.Cleanup(func() {
		if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
			t.Logf("delete IPv6 fallback default route: %v", err)
		}
	})
}
@@ -1,34 +0,0 @@
//go:build windows

package systemops

import (
	"bytes"
	"os/exec"
	"testing"
)

const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"

// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
	t.Helper()

	script := `New-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -RouteMetric 9999 -PolicyStore ActiveStore -ErrorAction Stop`
	out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
	if err != nil {
		// Existing default; nothing to install or clean up.
		if bytes.Contains(out, []byte("already exists")) {
			return
		}
		t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
	}
	t.Cleanup(func() {
		script := `Remove-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -Confirm:$false -ErrorAction Stop`
		if out, err := exec.Command("powershell", "-Command", script).CombinedOutput(); err != nil {
			t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
		}
	})
}
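All three build-tagged variants expose the same signature, so callers need no platform switch: the helper either finds an existing default, installs a fallback with automatic cleanup, or skips the test. A hypothetical caller (this exact test does not exist in the diff):

package systemops

import (
	"net/netip"
	"testing"
)

func TestV6LookupWithFallback(t *testing.T) {
	// Installs a loopback default only when the host has no v6 route;
	// cleanup (or the skip) is handled inside the helper.
	ensureIPv6DefaultRoute(t)

	if _, err := GetNextHop(netip.IPv6Unspecified()); err != nil {
		t.Fatalf("next hop lookup: %v", err)
	}
}
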
@@ -7,6 +7,7 @@ import (
	"sync"

	"github.com/hashicorp/go-multierror"
	"golang.org/x/exp/maps"

	"github.com/netbirdio/netbird/client/errors"
	"github.com/netbirdio/netbird/route"
@@ -43,8 +44,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
	if rs.selectedRoutes == nil {
		rs.selectedRoutes = map[route.NetID]struct{}{}
	}
	clear(rs.deselectedRoutes)
	clear(rs.selectedRoutes)
	maps.Clear(rs.deselectedRoutes)
	maps.Clear(rs.selectedRoutes)
	for _, r := range allRoutes {
		rs.deselectedRoutes[r] = struct{}{}
	}
@@ -77,8 +78,8 @@ func (rs *RouteSelector) SelectAllRoutes() {
	if rs.selectedRoutes == nil {
		rs.selectedRoutes = map[route.NetID]struct{}{}
	}
	clear(rs.deselectedRoutes)
	clear(rs.selectedRoutes)
	maps.Clear(rs.deselectedRoutes)
	maps.Clear(rs.selectedRoutes)
}

// DeselectRoutes removes specific routes from the selection.
@@ -115,8 +116,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
	if rs.selectedRoutes == nil {
		rs.selectedRoutes = map[route.NetID]struct{}{}
	}
	clear(rs.deselectedRoutes)
	clear(rs.selectedRoutes)
	maps.Clear(rs.deselectedRoutes)
	maps.Clear(rs.selectedRoutes)
}

// IsSelected checks if a specific route is selected.

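The two calls are interchangeable: the clear builtin (Go 1.21+) and maps.Clear from golang.org/x/exp/maps both delete every key from the map in place, so this swap only moves the minimum-toolchain requirement, not behavior. A quick equivalence check:

package main

import (
	"fmt"

	"golang.org/x/exp/maps"
)

func main() {
	a := map[string]struct{}{"x": {}, "y": {}}
	b := map[string]struct{}{"x": {}, "y": {}}

	clear(a)      // builtin, Go 1.21+
	maps.Clear(b) // x/exp equivalent, works on earlier toolchains

	fmt.Println(len(a), len(b)) // 0 0 (both emptied in place)
}
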
@@ -2,358 +2,217 @@

package sleep

/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>

extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();

// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;

static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
	switch (messageType) {
		case kIOMessageSystemWillSleep:
			sleepCallbackBridge();
			IOAllowPowerChange(g_rootPort, (long)messageArgument);
			break;
		case kIOMessageSystemHasPoweredOn:
			poweredOnCallbackBridge();
			break;
		case kIOMessageServiceIsSuspended:
			suspendedCallbackBridge();
			break;
		case kIOMessageServiceIsResumed:
			resumedCallbackBridge();
			break;
		default:
			break;
	}
}

static void registerNotifications() {
	g_rootPort = IORegisterForSystemPower(
		NULL,
		&g_notifyPortRef,
		(IOServiceInterestCallback)sleepCallback,
		&g_notifierObject
	);

	if (g_rootPort == 0) {
		return;
	}

	CFRunLoopAddSource(CFRunLoopGetCurrent(),
		IONotificationPortGetRunLoopSource(g_notifyPortRef),
		kCFRunLoopCommonModes);

	g_runLoop = CFRunLoopGetCurrent();
	CFRunLoopRun();
}

static void unregisterNotifications() {
	CFRunLoopRemoveSource(g_runLoop,
		IONotificationPortGetRunLoopSource(g_notifyPortRef),
		kCFRunLoopCommonModes);

	IODeregisterForSystemPower(&g_notifierObject);
	IOServiceClose(g_rootPort);
	IONotificationPortDestroy(g_notifyPortRef);
	CFRunLoopStop(g_runLoop);

	g_notifyPortRef = NULL;
	g_notifierObject = 0;
	g_rootPort = 0;
	g_runLoop = NULL;
}

*/
import "C"

import (
	"context"
	"fmt"
	"runtime"
	"sync"
	"time"
	"unsafe"

	"github.com/ebitengine/purego"
	log "github.com/sirupsen/logrus"
)

// IOKit message types from IOKit/IOMessage.h.
const (
	kIOMessageCanSystemSleep     uintptr = 0xe0000270
	kIOMessageSystemWillSleep    uintptr = 0xe0000280
	kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
)

var (
	ioKit         iokitFuncs
	cf            cfFuncs
	cfCommonModes uintptr

	libInitOnce sync.Once
	libInitErr  error

	// callbackThunk is the single C-callable trampoline registered with IOKit.
	callbackThunk uintptr

	serviceRegistry   = make(map[*Detector]struct{})
	serviceRegistryMu sync.Mutex
	session           *runLoopSession

	// lifecycleMu serializes Register/Deregister so a new registration can't
	// start a second runloop while a previous teardown is still pending.
	lifecycleMu sync.Mutex
)

// iokitFuncs holds IOKit symbols resolved once at init.
type iokitFuncs struct {
	IORegisterForSystemPower           func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
	IODeregisterForSystemPower         func(notifier *uintptr) int32
	IOAllowPowerChange                 func(kernelPort uintptr, notificationID uintptr) int32
	IOServiceClose                     func(connect uintptr) int32
	IONotificationPortGetRunLoopSource func(port uintptr) uintptr
	IONotificationPortDestroy          func(port uintptr)
}

// cfFuncs holds CoreFoundation symbols resolved once at init.
type cfFuncs struct {
	CFRunLoopGetCurrent   func() uintptr
	CFRunLoopRun          func()
	CFRunLoopStop         func(rl uintptr)
	CFRunLoopAddSource    func(rl, source, mode uintptr)
	CFRunLoopRemoveSource func(rl, source, mode uintptr)
}

// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
// session means no runloop is active and the next Register must start one.
type runLoopSession struct {
	rl       uintptr
	port     uintptr
	notifier uintptr
	rp       uintptr
}

// detectorSnapshot pins a detector's callback and done channel so dispatch
// runs with values valid at snapshot time, even if a concurrent
// Deregister/Register rewrites the detector's fields.
type detectorSnapshot struct {
	detector *Detector
	callback func(event EventType)
	done     <-chan struct{}
}

// Detector delivers sleep and wake events to a registered callback.
type Detector struct {
	callback func(event EventType)
	done     chan struct{}
}

// Register installs callback for power events. The first registration starts
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
// registration succeeds or fails; subsequent registrations just add to the
// dispatch set.
func (d *Detector) Register(callback func(event EventType)) error {
	lifecycleMu.Lock()
	defer lifecycleMu.Unlock()
//export sleepCallbackBridge
func sleepCallbackBridge() {
	log.Info("sleepCallbackBridge event triggered")

	serviceRegistryMu.Lock()
	defer serviceRegistryMu.Unlock()

	for svc := range serviceRegistry {
		svc.triggerCallback(EventTypeSleep)
	}
}

//export resumedCallbackBridge
func resumedCallbackBridge() {
	log.Info("resumedCallbackBridge event triggered")
}

//export suspendedCallbackBridge
func suspendedCallbackBridge() {
	log.Info("suspendedCallbackBridge event triggered")
}

//export poweredOnCallbackBridge
func poweredOnCallbackBridge() {
	log.Info("poweredOnCallbackBridge event triggered")
	serviceRegistryMu.Lock()
	defer serviceRegistryMu.Unlock()

	for svc := range serviceRegistry {
		svc.triggerCallback(EventTypeWakeUp)
	}
}

type Detector struct {
	callback func(event EventType)
	ctx      context.Context
	cancel   context.CancelFunc
}

func NewDetector() (*Detector, error) {
	return &Detector{}, nil
}

func (d *Detector) Register(callback func(event EventType)) error {
	serviceRegistryMu.Lock()
	defer serviceRegistryMu.Unlock()

	if _, exists := serviceRegistry[d]; exists {
		serviceRegistryMu.Unlock()
		return fmt.Errorf("detector service already registered")
	}
	d.callback = callback
	d.done = make(chan struct{})
	serviceRegistry[d] = struct{}{}
	needSetup := session == nil
	serviceRegistryMu.Unlock()

	if !needSetup {
	d.callback = callback

	d.ctx, d.cancel = context.WithCancel(context.Background())

	if len(serviceRegistry) > 0 {
		serviceRegistry[d] = struct{}{}
		return nil
	}

	errCh := make(chan error, 1)
	go runRunLoop(errCh)
	if err := <-errCh; err != nil {
		serviceRegistryMu.Lock()
		delete(serviceRegistry, d)
		close(d.done)
		d.done = nil
		serviceRegistryMu.Unlock()
		return err
	}
	serviceRegistry[d] = struct{}{}

	// CFRunLoop must run on a single fixed OS thread
	go func() {
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()

		C.registerNotifications()
	}()

	log.Info("sleep detection service started on macOS")
	return nil
}

// Deregister removes the detector. When the last detector leaves, IOKit
// notifications are torn down and the runloop is stopped.
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
// and the runloop is stopped and cleaned up.
func (d *Detector) Deregister() error {
	lifecycleMu.Lock()
	defer lifecycleMu.Unlock()

	serviceRegistryMu.Lock()
	if _, exists := serviceRegistry[d]; !exists {
		serviceRegistryMu.Unlock()
	defer serviceRegistryMu.Unlock()
	_, exists := serviceRegistry[d]
	if !exists {
		return nil
	}
	close(d.done)

	// cancel and remove this detector
	d.cancel()
	delete(serviceRegistry, d)

	// If other Detectors still exist, leave IOKit running
	if len(serviceRegistry) > 0 {
		serviceRegistryMu.Unlock()
		return nil
	}
	sess := session
	serviceRegistryMu.Unlock()

	log.Info("sleep detection service stopping (deregister)")

	if sess == nil {
		return nil
	}

	if sess.rl != 0 && sess.port != 0 {
		source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
		cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
	}
	if sess.notifier != 0 {
		n := sess.notifier
		ioKit.IODeregisterForSystemPower(&n)
	}

	// Clear session only after IODeregisterForSystemPower returns so any
	// in-flight powerCallback can still look up session.rp to ack sleep.
	serviceRegistryMu.Lock()
	session = nil
	serviceRegistryMu.Unlock()

	if sess.rp != 0 {
		ioKit.IOServiceClose(sess.rp)
	}
	if sess.port != 0 {
		ioKit.IONotificationPortDestroy(sess.port)
	}
	if sess.rl != 0 {
		cf.CFRunLoopStop(sess.rl)
	}
	// Deregister IOKit notifications, stop runloop, and free resources
	C.unregisterNotifications()

	return nil
}

func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
	if cb == nil || done == nil {
		return
	}

	select {
	case <-done:
		return
	default:
	}

func (d *Detector) triggerCallback(event EventType) {
	doneChan := make(chan struct{})

	timeout := time.NewTimer(500 * time.Millisecond)
	defer timeout.Stop()

	go func() {
		defer close(doneChan)
		defer func() {
			if r := recover(); r != nil {
				log.Errorf("panic in sleep callback: %v", r)
			}
		}()
		cb := d.callback
	go func(callback func(event EventType)) {
		log.Info("sleep detection event fired")
		cb(event)
	}()
		callback(event)
		close(doneChan)
	}(cb)

	select {
	case <-doneChan:
	case <-done:
	case <-d.ctx.Done():
	case <-timeout.C:
		log.Warn("sleep callback timed out")
		log.Warnf("sleep callback timed out")
	}
}

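Both triggerCallback generations bound the callback with a 500 ms timer and contain panics, differing mainly in whether cancellation comes from a done channel or a context. The guard pattern itself, extracted into a generic sketch (runBounded is not a NetBird function):

package main

import (
	"log"
	"time"
)

// runBounded invokes cb in a goroutine and waits at most timeout,
// also giving up when done closes, so one stuck consumer cannot
// wedge the event dispatcher. Panics are contained in the worker.
func runBounded(cb func(), done <-chan struct{}, timeout time.Duration) {
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		defer func() {
			if r := recover(); r != nil {
				log.Printf("panic in callback: %v", r)
			}
		}()
		cb()
	}()

	t := time.NewTimer(timeout)
	defer t.Stop()
	select {
	case <-finished:
	case <-done:
	case <-t.C:
		log.Print("callback timed out")
	}
}

func main() {
	done := make(chan struct{})
	runBounded(func() { log.Print("event delivered") }, done, 500*time.Millisecond)
}
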
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
func NewDetector() (*Detector, error) {
	if err := initLibs(); err != nil {
		return nil, err
	}
	return &Detector{}, nil
}

func initLibs() error {
	libInitOnce.Do(func() {
		iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
		if err != nil {
			libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
			return
		}
		cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
		if err != nil {
			libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
			return
		}

		purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
		purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
		purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
		purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
		purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
		purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")

		purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
		purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
		purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
		purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
		purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")

		modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
		if err != nil {
			libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
			return
		}
		// Launder the uintptr-to-pointer conversion through a Go variable so
		// go vet's unsafeptr analyzer doesn't flag a system-library global.
		cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))

		// NewCallback slots are a finite, non-reclaimable resource, so register
		// a single thunk that dispatches to the current Detector set.
		callbackThunk = purego.NewCallback(powerCallback)
	})
	return libInitErr
}

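initLibs shows the usual purego shape: Dlopen each framework once, bind every needed symbol into a typed Go function variable with RegisterLibFunc, and mint exactly one NewCallback thunk because callback slots are never reclaimed. A pared-down sketch of the same wiring against libSystem (the getpid binding is only a demonstrator, not part of this diff):

//go:build darwin

package main

import (
	"fmt"

	"github.com/ebitengine/purego"
)

func main() {
	// Resolve the library once; RTLD_GLOBAL mirrors initLibs above.
	lib, err := purego.Dlopen("/usr/lib/libSystem.B.dylib", purego.RTLD_NOW|purego.RTLD_GLOBAL)
	if err != nil {
		panic(err)
	}

	// Bind a C symbol to a typed Go function variable.
	var getpid func() int32
	purego.RegisterLibFunc(&getpid, lib, "getpid")

	fmt.Println("pid:", getpid())
}
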
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
// runloop thread. A Go panic crossing the purego boundary has undefined
// behavior, so contain it here.
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
	defer func() {
		if r := recover(); r != nil {
			log.Errorf("panic in sleep powerCallback: %v", r)
		}
	}()
	switch messageType {
	case kIOMessageCanSystemSleep:
		// Not acknowledging forces a 30s IOKit timeout before idle sleep.
		allowPowerChange(messageArgument)
	case kIOMessageSystemWillSleep:
		dispatchEvent(EventTypeSleep)
		allowPowerChange(messageArgument)
	case kIOMessageSystemHasPoweredOn:
		dispatchEvent(EventTypeWakeUp)
	}
	return 0
}

func allowPowerChange(messageArgument uintptr) {
	serviceRegistryMu.Lock()
	var port uintptr
	if session != nil {
		port = session.rp
	}
	serviceRegistryMu.Unlock()
	if port != 0 {
		ioKit.IOAllowPowerChange(port, messageArgument)
	}
}

func dispatchEvent(event EventType) {
	serviceRegistryMu.Lock()
	snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
	for d := range serviceRegistry {
		snaps = append(snaps, detectorSnapshot{
			detector: d,
			callback: d.callback,
			done:     d.done,
		})
	}
	serviceRegistryMu.Unlock()

	for _, s := range snaps {
		s.detector.triggerCallback(event, s.callback, s.done)
	}
}

// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
// result is reported on errCh so Register can surface failures synchronously.
func runRunLoop(errCh chan<- error) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	sess, err := setupSession()
	if err == nil {
		serviceRegistryMu.Lock()
		session = sess
		serviceRegistryMu.Unlock()
	}
	errCh <- err
	if err != nil {
		return
	}

	defer func() {
		if r := recover(); r != nil {
			log.Errorf("panic in sleep runloop: %v", r)
		}
	}()
	cf.CFRunLoopRun()
}

// setupSession performs the IOKit registration on the current thread. Panics
// are converted to errors so runRunLoop never leaves errCh unsent.
func setupSession() (s *runLoopSession, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic during runloop setup: %v", r)
		}
	}()

	var portRef, notifier uintptr
	rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, &notifier)
	if rp == 0 {
		return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
	}

	rl := cf.CFRunLoopGetCurrent()
	source := ioKit.IONotificationPortGetRunLoopSource(portRef)
	cf.CFRunLoopAddSource(rl, source, cfCommonModes)

	return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
}

@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
	}
}

// FilePath returns the path of the underlying state file.
func (m *Manager) FilePath() string {
	if m == nil {
		return ""
	}
	return m.filePath
}

// Start starts the state manager periodic save routine
func (m *Manager) Start() {
	if m == nil {

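FilePath keeps to the package's nil-receiver convention visible in Start below it, so a caller holding an optional *Manager never needs a guard at the call site. A tiny illustration of why that works; Manager here is a stripped-down stand-in, not the NetBird type:

package main

import "fmt"

type Manager struct{ filePath string }

// FilePath is safe to call through a nil *Manager: pointer methods
// receive the (nil) receiver like an argument, so the guard runs fine.
func (m *Manager) FilePath() string {
	if m == nil {
		return ""
	}
	return m.filePath
}

func main() {
	var m *Manager // optional dependency left unset
	fmt.Printf("path=%q\n", m.FilePath()) // path="" (no panic)
}
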
@@ -6,6 +6,7 @@ import (
	"fmt"
	"net"
	"runtime"
	"time"

	log "github.com/sirupsen/logrus"

@@ -27,10 +28,6 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {

// Start begins monitoring the WireGuard interface.
// It relies on the provided context cancellation to stop.
//
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
// to avoid the allocation churn of repeatedly dumping the kernel link
// table; on other platforms it falls back to a low-frequency poll.
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
	defer close(m.done)

@@ -59,7 +56,31 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes

	log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)

	return watchInterface(ctx, ifaceName, expectedIndex)
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			log.Infof("Interface monitor: stopped for %s", ifaceName)
			return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
		case <-ticker.C:
			currentIndex, err := getInterfaceIndex(ifaceName)
			if err != nil {
				// Interface was deleted
				log.Infof("Interface monitor: %s deleted", ifaceName)
				return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
			}

			// Check if interface index changed (interface was recreated)
			if currentIndex != expectedIndex {
				log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
					ifaceName, expectedIndex, currentIndex)
				return true, nil
			}
		}
	}

}

// getInterfaceIndex returns the index of a network interface by name.

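The removed doc comment describes the event-driven Linux path as an RTNLGRP_LINK netlink subscription. With github.com/vishvananda/netlink that watcher could look roughly like the sketch below; this is a simplified illustration, not the removed watchInterface implementation (deletion events, for instance, are not handled separately here):

//go:build linux

package main

import (
	"context"
	"fmt"

	"github.com/vishvananda/netlink"
)

// watchLink blocks until the named interface reappears under a new index,
// using a netlink link subscription instead of polling the link table.
func watchLink(ctx context.Context, name string, index int) (recreated bool, err error) {
	updates := make(chan netlink.LinkUpdate)
	done := make(chan struct{})
	defer close(done)

	if err := netlink.LinkSubscribe(updates, done); err != nil {
		return false, fmt.Errorf("subscribe RTNLGRP_LINK: %w", err)
	}

	for {
		select {
		case <-ctx.Done():
			return false, ctx.Err()
		case u := <-updates:
			if u.Attrs() == nil || u.Attrs().Name != name {
				continue
			}
			if int(u.Index) != index {
				return true, nil // interface recreated with a new index
			}
		}
	}
}
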
Some files were not shown because too many files have changed in this diff.