mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 17:08:53 +00:00
Merge branch 'main' into reduce-embed-wg-pool
This commit is contained in:
130
.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml
vendored
Normal file
130
.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Ideas & Feature Requests
|
||||
|
||||
Use this category for feature requests, enhancements, integrations, and product ideas.
|
||||
|
||||
NetBird uses community traction in discussions — upvotes, replies, affected users, and use-case detail — as an input when deciding what should become a maintainer-curated issue or roadmap item. A clear problem statement is more useful than a solution-only request.
|
||||
|
||||
Please search first and add your use case to an existing discussion when one already exists.
|
||||
|
||||
- type: checkboxes
|
||||
id: preflight
|
||||
attributes:
|
||||
label: Before posting
|
||||
options:
|
||||
- label: I searched existing discussions and issues for similar requests.
|
||||
required: true
|
||||
- label: I checked the documentation to confirm this is not already supported.
|
||||
required: true
|
||||
- label: This is a product idea or enhancement request, not a support question.
|
||||
required: true
|
||||
- label: I removed or anonymized sensitive details from examples and screenshots.
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: area
|
||||
attributes:
|
||||
label: Product area
|
||||
description: Select every area this request touches.
|
||||
multiple: true
|
||||
options:
|
||||
- Client / Agent
|
||||
- CLI
|
||||
- Desktop UI
|
||||
- Mobile app
|
||||
- Dashboard / Admin UI
|
||||
- Management service / API
|
||||
- Signal service
|
||||
- Relay
|
||||
- DNS
|
||||
- Routes / Exit nodes
|
||||
- NetBird SSH
|
||||
- Access control policies
|
||||
- Posture checks
|
||||
- Identity provider / SSO
|
||||
- Self-hosting / Deployment
|
||||
- Kubernetes / Operator
|
||||
- Terraform / Automation
|
||||
- Documentation
|
||||
- Other / not sure
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: Problem or use case
|
||||
description: What are you trying to accomplish, and what is difficult or impossible today?
|
||||
placeholder: |
|
||||
As a ...
|
||||
I want to ...
|
||||
Because ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: proposal
|
||||
attributes:
|
||||
label: Proposed solution
|
||||
description: Describe the behavior, workflow, API, UI, or integration you would like to see.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives or workarounds considered
|
||||
description: What have you tried today? Why is the current workaround not enough?
|
||||
|
||||
- type: textarea
|
||||
id: impact
|
||||
attributes:
|
||||
label: Community impact and priority
|
||||
description: Help us understand who benefits and how urgent this is.
|
||||
placeholder: |
|
||||
- Number of users/teams/peers affected:
|
||||
- Deployment type: Cloud / self-hosted / both
|
||||
- Frequency: daily / weekly / occasional
|
||||
- Blocking production adoption? yes/no
|
||||
- Related comments, discussions, or customer requests:
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: examples
|
||||
attributes:
|
||||
label: Examples from other tools or products
|
||||
description: If another tool solves this well, link or describe the behavior.
|
||||
|
||||
- type: textarea
|
||||
id: security
|
||||
attributes:
|
||||
label: Security, privacy, and compatibility considerations
|
||||
description: Note any access-control, audit, data retention, network, platform, or backward-compatibility concerns.
|
||||
|
||||
- type: textarea
|
||||
id: implementation
|
||||
attributes:
|
||||
label: Implementation ideas
|
||||
description: Optional. If you are familiar with the codebase or API, share possible implementation notes.
|
||||
|
||||
- type: dropdown
|
||||
id: contribution
|
||||
attributes:
|
||||
label: Are you willing to help?
|
||||
options:
|
||||
- Yes, I can submit a PR if the approach is accepted.
|
||||
- Yes, I can test or validate a proposed implementation.
|
||||
- Yes, I can provide more use-case details.
|
||||
- Not at this time.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: Add screenshots, diagrams, links, or anything else that helps explain the request.
|
||||
237
.github/DISCUSSION_TEMPLATE/issue-triage.yml
vendored
Normal file
237
.github/DISCUSSION_TEMPLATE/issue-triage.yml
vendored
Normal file
@@ -0,0 +1,237 @@
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Issue Triage
|
||||
|
||||
Use this category for reproducible bugs and regressions in NetBird.
|
||||
|
||||
The more context you include, the faster we can validate and act on your report. If you're not sure whether something is a bug, **Q&A / Support** is a good starting point — we can always move the conversation here once we've confirmed it's a product issue.
|
||||
|
||||
Intermittent issues are useful too. Include the trigger, frequency, timing, and any logs or debug evidence you have, and we'll work from there.
|
||||
|
||||
Please don't include secrets, tokens, private keys, internal hostnames, or public IPs. Security vulnerabilities should be reported through the repository security policy rather than a public discussion.
|
||||
|
||||
- type: checkboxes
|
||||
id: preflight
|
||||
attributes:
|
||||
label: Before posting
|
||||
options:
|
||||
- label: I searched existing discussions and issues, including closed ones, and checked the relevant docs.
|
||||
required: true
|
||||
- label: I believe this is a product bug rather than a configuration or setup question.
|
||||
required: true
|
||||
- label: I can reproduce this issue, or for intermittent issues I've included trigger, frequency, and timing details below.
|
||||
required: true
|
||||
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: area
|
||||
attributes:
|
||||
label: Affected area
|
||||
description: Select every area this report touches.
|
||||
multiple: true
|
||||
options:
|
||||
- Client / Agent
|
||||
- Reverse Proxy
|
||||
- CLI
|
||||
- Desktop UI
|
||||
- Mobile app
|
||||
- Peer connectivity
|
||||
- DNS
|
||||
- Routes / Exit nodes
|
||||
- NetBird SSH
|
||||
- Relay / Signal / NAT traversal
|
||||
- Login / Authentication / IdP
|
||||
- Dashboard / Admin UI
|
||||
- Management service / API
|
||||
- Access control policies / Posture checks
|
||||
- Self-hosting / Deployment
|
||||
- Kubernetes / Operator
|
||||
- Documentation
|
||||
- Other / not sure
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: deployment
|
||||
attributes:
|
||||
label: Deployment type
|
||||
options:
|
||||
- NetBird Cloud
|
||||
- Self-hosted - quickstart script
|
||||
- Self-hosted - advanced/custom deployment
|
||||
- Local development build
|
||||
- Not sure / environment I do not fully control
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: platform
|
||||
attributes:
|
||||
label: Operating system or environment
|
||||
description: Select every environment involved in the reproduction.
|
||||
multiple: true
|
||||
options:
|
||||
- Linux
|
||||
- macOS
|
||||
- Windows
|
||||
- Android
|
||||
- iOS
|
||||
- FreeBSD
|
||||
- OpenWRT
|
||||
- Docker
|
||||
- Kubernetes
|
||||
- Synology
|
||||
- Browser
|
||||
- Other / not sure
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
label: NetBird version and upgrade status
|
||||
description: Run `netbird version` where applicable. For self-hosted deployments, include management, signal, relay, and dashboard versions if available. If you cannot test on a current/supported version, explain why.
|
||||
placeholder: |
|
||||
Example:
|
||||
- Client: 0.30.2
|
||||
- Management: 0.30.2
|
||||
- Signal: 0.30.2
|
||||
- Relay: 0.30.2
|
||||
- Dashboard: 0.30.2
|
||||
- Upgrade status: reproduced on current version / cannot upgrade because ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: regression
|
||||
attributes:
|
||||
label: Did this work before?
|
||||
options:
|
||||
- Yes, this worked before
|
||||
- No, this never worked
|
||||
- Not sure
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: regression-details
|
||||
attributes:
|
||||
label: Regression details
|
||||
description: If this worked before, include the last known working version, first known broken version, and any recent upgrade, configuration, network, or IdP changes.
|
||||
placeholder: |
|
||||
- Last known working version:
|
||||
- First known broken version:
|
||||
- Recent changes:
|
||||
|
||||
- type: textarea
|
||||
id: summary
|
||||
attributes:
|
||||
label: Summary
|
||||
description: Briefly describe the reproducible bug.
|
||||
placeholder: What is broken?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: current-behavior
|
||||
attributes:
|
||||
label: Current behavior
|
||||
description: What happens now? Include exact errors, timeouts, UI messages, or failed commands when possible.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: What did you expect to happen instead?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
attributes:
|
||||
label: Steps to reproduce
|
||||
description: Provide the smallest set of steps that reliably reproduces the bug. If the issue is intermittent, include the trigger, frequency, timing, and relevant timestamps.
|
||||
placeholder: |
|
||||
1. Configure ...
|
||||
2. Run ...
|
||||
3. Observe ...
|
||||
|
||||
For intermittent issues:
|
||||
- Trigger:
|
||||
- Frequency:
|
||||
- Timing/timestamps:
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: environment
|
||||
attributes:
|
||||
label: Environment and topology
|
||||
description: Include the relevant topology and software involved in the reproduction. For UI/docs-only reports, write `N/A` if this does not apply. Use `None`, `Unknown`, or `N/A` where appropriate.
|
||||
placeholder: |
|
||||
- Peer A:
|
||||
- Peer B:
|
||||
- Same LAN or different networks:
|
||||
- NAT/CGNAT/corporate firewall/mobile network:
|
||||
- Other VPN software:
|
||||
- Firewall, DNS, or endpoint security software:
|
||||
- Routes, DNS, policies, posture checks, or SSH rules involved:
|
||||
- IdP, reverse proxy, or browser involved:
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: self-hosted-details
|
||||
attributes:
|
||||
label: Self-hosted details, if available
|
||||
description: Optional. If you use self-hosting and have access to these details, include them. If you do not administer the environment, provide what you know and say what you cannot access.
|
||||
placeholder: |
|
||||
- Deployment method: quickstart / Docker Compose / Helm / operator / custom
|
||||
- Management/signal/relay/dashboard versions:
|
||||
- Reverse proxy:
|
||||
- IdP/provider:
|
||||
- STUN/TURN/coturn/relay details:
|
||||
- Relevant component logs:
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Logs, status output, or debug evidence
|
||||
description: |
|
||||
For client, connectivity, DNS, route, relay/signal, or self-hosted reports, logs are essential — please include anonymized output from `netbird status -dA`, or a debug bundle via `netbird debug for 1m -AS -U`. Debug bundles are automatically deleted after 30 days.
|
||||
|
||||
For UI, dashboard, or documentation reports, leave the pre-filled `N/A`.
|
||||
value: "N/A"
|
||||
render: shell
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: related-reports
|
||||
attributes:
|
||||
label: Related issues or discussions
|
||||
description: Optional. Link similar reports you found while searching, if any.
|
||||
placeholder: |
|
||||
- Related issue/discussion:
|
||||
- Why this may be the same or different:
|
||||
|
||||
- type: textarea
|
||||
id: impact
|
||||
attributes:
|
||||
label: Impact
|
||||
description: Optional. Help us understand priority. How many users, peers, environments, or workflows are affected? Is there a workaround?
|
||||
placeholder: |
|
||||
- Affected users/peers:
|
||||
- Business or production impact:
|
||||
- Workaround available:
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: Add links to related discussions, issues, docs, screenshots, recordings, or anything else that may help validation.
|
||||
146
.github/DISCUSSION_TEMPLATE/q-a-support.yml
vendored
Normal file
146
.github/DISCUSSION_TEMPLATE/q-a-support.yml
vendored
Normal file
@@ -0,0 +1,146 @@
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Q&A / Support
|
||||
|
||||
Use this category for questions about configuration, setup, self-hosted deployments, troubleshooting, and general NetBird usage.
|
||||
|
||||
This is community support and does not provide an SLA. For NetBird Cloud support, use the official support channel linked from the issue creation page. Please do not post secrets, tokens, private keys, internal hostnames, or public IPs unless you intentionally want them public.
|
||||
|
||||
If your question turns into a reproducible product defect, DevRel or a maintainer may ask you to open or move the conversation to Issue Triage.
|
||||
|
||||
- type: checkboxes
|
||||
id: preflight
|
||||
attributes:
|
||||
label: Before posting
|
||||
options:
|
||||
- label: I searched existing discussions and issues for similar questions.
|
||||
required: true
|
||||
- label: I reviewed the relevant NetBird documentation or troubleshooting guide.
|
||||
required: true
|
||||
- label: I removed or anonymized sensitive data from logs, screenshots, and configuration.
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: topic
|
||||
attributes:
|
||||
label: Topic
|
||||
multiple: true
|
||||
options:
|
||||
- Getting started
|
||||
- Self-hosting
|
||||
- Client / Agent
|
||||
- CLI
|
||||
- Desktop UI
|
||||
- Mobile app
|
||||
- Dashboard / Admin UI
|
||||
- DNS
|
||||
- Routes / Exit nodes
|
||||
- NetBird SSH
|
||||
- Relay
|
||||
- Access control policies
|
||||
- Posture checks
|
||||
- Identity provider / SSO
|
||||
- API
|
||||
- Kubernetes / Operator
|
||||
- Terraform / Automation
|
||||
- Documentation
|
||||
- Other / not sure
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: deployment
|
||||
attributes:
|
||||
label: Deployment type
|
||||
options:
|
||||
- NetBird Cloud
|
||||
- Self-hosted - quickstart script
|
||||
- Self-hosted - advanced/custom deployment
|
||||
- Local development build
|
||||
- Not sure
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: platform
|
||||
attributes:
|
||||
label: Operating system or environment
|
||||
multiple: true
|
||||
options:
|
||||
- Linux
|
||||
- macOS
|
||||
- Windows
|
||||
- Android
|
||||
- iOS
|
||||
- FreeBSD
|
||||
- OpenWRT
|
||||
- Docker
|
||||
- Kubernetes
|
||||
- Synology
|
||||
- Browser
|
||||
- Other / not sure
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: NetBird version
|
||||
description: Run `netbird version` where applicable. For self-hosted deployments, include component versions if relevant.
|
||||
placeholder: "Example: client 0.30.2, management 0.30.2"
|
||||
|
||||
- type: textarea
|
||||
id: question
|
||||
attributes:
|
||||
label: Question
|
||||
description: What are you trying to understand or accomplish?
|
||||
placeholder: Describe your question clearly.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: goal
|
||||
attributes:
|
||||
label: Desired outcome
|
||||
description: What would a successful answer help you do?
|
||||
placeholder: |
|
||||
I want to configure ...
|
||||
I expected ...
|
||||
I need help deciding ...
|
||||
|
||||
- type: textarea
|
||||
id: attempted
|
||||
attributes:
|
||||
label: What have you tried?
|
||||
description: Include commands, documentation links, configuration attempts, or troubleshooting steps already tried.
|
||||
placeholder: |
|
||||
- Read ...
|
||||
- Ran ...
|
||||
- Changed ...
|
||||
- Observed ...
|
||||
|
||||
- type: textarea
|
||||
id: environment
|
||||
attributes:
|
||||
label: Relevant environment details
|
||||
description: Include redacted topology, IdP/provider, reverse proxy, firewall, DNS, route, policy, or self-hosted setup details that may affect the answer.
|
||||
placeholder: |
|
||||
- Deployment:
|
||||
- Components involved:
|
||||
- Network/topology:
|
||||
- Related config:
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Logs or output
|
||||
description: Optional. Include anonymized logs, command output, screenshots, or `netbird status -dA` if relevant.
|
||||
render: shell
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: Add links, diagrams, screenshots, or other details that may help the community answer.
|
||||
71
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
71
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -1,71 +0,0 @@
|
||||
---
|
||||
name: Bug/Issue report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ['triage-needed']
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the problem**
|
||||
|
||||
A clear and concise description of what the problem is.
|
||||
|
||||
**To Reproduce**
|
||||
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Are you using NetBird Cloud?**
|
||||
|
||||
Please specify whether you use NetBird Cloud or self-host NetBird's control plane.
|
||||
|
||||
**NetBird version**
|
||||
|
||||
`netbird version`
|
||||
|
||||
**Is any other VPN software installed?**
|
||||
|
||||
If yes, which one?
|
||||
|
||||
**Debug output**
|
||||
|
||||
To help us resolve the problem, please attach the following anonymized status output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
Create and upload a debug bundle, and share the returned file key:
|
||||
|
||||
netbird debug for 1m -AS -U
|
||||
|
||||
*Uploaded files are automatically deleted after 30 days.*
|
||||
|
||||
|
||||
Alternatively, create the file only and attach it here manually:
|
||||
|
||||
netbird debug for 1m -AS
|
||||
|
||||
|
||||
**Screenshots**
|
||||
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Additional context**
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
26
.github/ISSUE_TEMPLATE/config.yml
vendored
26
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,14 +1,26 @@
|
||||
blank_issues_enabled: true
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Community Support
|
||||
- name: Start an Issue Triage discussion
|
||||
url: https://github.com/netbirdio/netbird/discussions/new?category=issue-triage
|
||||
about: Report a bug, regression, or unexpected behavior so DevRel can validate it before it becomes an issue.
|
||||
- name: Propose an idea or feature request
|
||||
url: https://github.com/netbirdio/netbird/discussions/new?category=ideas-feature-requests
|
||||
about: Share feature requests, enhancements, and integration ideas for community feedback and prioritization.
|
||||
- name: Ask a Q&A / Support question
|
||||
url: https://github.com/netbirdio/netbird/discussions/new?category=q-a-support
|
||||
about: Get help with setup, configuration, self-hosting, troubleshooting, and general usage.
|
||||
- name: Security vulnerability disclosure
|
||||
url: https://github.com/netbirdio/netbird/security/policy
|
||||
about: Please do not report security vulnerabilities in public issues or discussions.
|
||||
- name: Community Support Forum
|
||||
url: https://forum.netbird.io/
|
||||
about: Community support forum
|
||||
about: Community support forum.
|
||||
- name: Cloud Support
|
||||
url: https://docs.netbird.io/help/report-bug-issues
|
||||
about: Contact us for support
|
||||
- name: Client/Connection Troubleshooting
|
||||
about: Contact NetBird for Cloud support.
|
||||
- name: Client / Connection Troubleshooting
|
||||
url: https://docs.netbird.io/help/troubleshooting-client
|
||||
about: See our client troubleshooting guide for help addressing common issues
|
||||
about: See the client troubleshooting guide for common connectivity issues.
|
||||
- name: Self-host Troubleshooting
|
||||
url: https://docs.netbird.io/selfhosted/troubleshooting
|
||||
about: See our self-host troubleshooting guide for help addressing common issues
|
||||
about: See the self-host troubleshooting guide for common deployment issues.
|
||||
|
||||
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -1,20 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ['feature-request']
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
128
.github/ISSUE_TEMPLATE/validated_issue.yml
vendored
Normal file
128
.github/ISSUE_TEMPLATE/validated_issue.yml
vendored
Normal file
@@ -0,0 +1,128 @@
|
||||
name: Validated issue
|
||||
description: Maintainer/DevRel only. Create an issue after a discussion has been validated or for internally validated work.
|
||||
title: "[Validated]: "
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Discussion-first issue policy
|
||||
|
||||
Issues are maintainer-curated work items. Community reports and feature requests should start in [Discussions](https://github.com/netbirdio/netbird/discussions) so DevRel can validate, reproduce, and route them before engineering time is committed.
|
||||
|
||||
Use this form when:
|
||||
- A discussion has been validated and should become actionable work.
|
||||
- A maintainer is opening internally validated work that can bypass the discussion-first flow.
|
||||
|
||||
Issues opened without a relevant validated discussion or maintainer context may be closed and redirected to Discussions.
|
||||
|
||||
- type: checkboxes
|
||||
id: validation-checks
|
||||
attributes:
|
||||
label: Validation checklist
|
||||
options:
|
||||
- label: This issue is linked to a validated discussion, or it is being opened directly by a maintainer.
|
||||
required: true
|
||||
- label: The report has enough context for engineering to act on it without re-triaging from scratch.
|
||||
required: true
|
||||
- label: Sensitive data, secrets, private keys, internal hostnames, and public IPs have been removed or intentionally disclosed.
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: issue-type
|
||||
attributes:
|
||||
label: Issue type
|
||||
options:
|
||||
- Bug / Regression
|
||||
- Feature / Enhancement
|
||||
- Documentation
|
||||
- Maintenance / Refactor
|
||||
- Cross-repository coordination
|
||||
- Other
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: source-discussion
|
||||
attributes:
|
||||
label: Source discussion
|
||||
description: Link the GitHub Discussion that was validated. Maintainers bypassing the flow can write "Maintainer-created" and explain why below.
|
||||
placeholder: https://github.com/netbirdio/netbird/discussions/1234
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: validation-owner
|
||||
attributes:
|
||||
label: Validation owner
|
||||
description: GitHub handle of the DevRel team member or maintainer who validated this work.
|
||||
placeholder: "@username"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: target-repository
|
||||
attributes:
|
||||
label: Target repository
|
||||
description: Where should the implementation work happen?
|
||||
options:
|
||||
- netbirdio/netbird
|
||||
- netbirdio/dashboard
|
||||
- netbirdio/kubernetes-operator
|
||||
- netbirdio/docs
|
||||
- Multiple repositories
|
||||
- Unknown / needs routing
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: summary
|
||||
attributes:
|
||||
label: Summary
|
||||
description: Concise description of the validated work.
|
||||
placeholder: What needs to be fixed, changed, documented, or built?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: evidence
|
||||
attributes:
|
||||
label: Validation evidence
|
||||
description: For bugs, include reproduction status, affected versions, logs, and environment. For features, include community traction, affected users, and alignment notes.
|
||||
placeholder: |
|
||||
- Reproduced by:
|
||||
- Affected versions / platforms:
|
||||
- Community signal:
|
||||
- Related logs or screenshots:
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: scope
|
||||
attributes:
|
||||
label: Proposed scope
|
||||
description: Describe what is in scope and, if helpful, what is explicitly out of scope.
|
||||
placeholder: |
|
||||
In scope:
|
||||
- ...
|
||||
|
||||
Out of scope:
|
||||
- ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: acceptance-criteria
|
||||
attributes:
|
||||
label: Acceptance criteria
|
||||
description: What must be true for this issue to be closed?
|
||||
placeholder: |
|
||||
- [ ] ...
|
||||
- [ ] ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: Links to related PRs, docs, issues in other repositories, roadmap items, or implementation notes.
|
||||
@@ -58,6 +58,11 @@ linters:
|
||||
govet:
|
||||
enable:
|
||||
- nilness
|
||||
disable:
|
||||
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
|
||||
# directives but cannot perform the rewrite due to generic type
|
||||
# parameter inference limitations in the Go inliner.
|
||||
- inline
|
||||
enable-all: false
|
||||
revive:
|
||||
rules:
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
}
|
||||
@@ -256,7 +258,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
}
|
||||
|
||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
|
||||
|
||||
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -324,7 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
if err != nil {
|
||||
@@ -334,7 +336,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
@@ -348,6 +350,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
if showQR {
|
||||
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
|
||||
printQRCode(f, verificationURIComplete)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
|
||||
25
client/cmd/qr.go
Normal file
25
client/cmd/qr.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
)
|
||||
|
||||
// printQRCode prints a QR code for the given URL to the writer.
|
||||
// Called only when the user explicitly requests QR output via --qr.
|
||||
func printQRCode(w io.Writer, url string) {
|
||||
if url == "" {
|
||||
return
|
||||
}
|
||||
qrterminal.GenerateWithConfig(url, qrterminal.Config{
|
||||
Level: qrterminal.M,
|
||||
Writer: w,
|
||||
HalfBlocks: true,
|
||||
BlackChar: qrterminal.BLACK_BLACK,
|
||||
WhiteChar: qrterminal.WHITE_WHITE,
|
||||
BlackWhiteChar: qrterminal.BLACK_WHITE,
|
||||
WhiteBlackChar: qrterminal.WHITE_BLACK,
|
||||
QuietZone: qrterminal.QUIET_ZONE,
|
||||
})
|
||||
}
|
||||
26
client/cmd/qr_test.go
Normal file
26
client/cmd/qr_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrintQRCode_EmptyURL(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "")
|
||||
|
||||
if buf.Len() != 0 {
|
||||
t.Error("expected no output for empty URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintQRCode_WritesOutput(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "https://example.com/auth")
|
||||
|
||||
if buf.Len() == 0 {
|
||||
t.Error("expected QR code output for non-empty URL")
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,9 @@ const (
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
|
||||
showQRFlag = "qr"
|
||||
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
|
||||
|
||||
profileNameFlag = "profile"
|
||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||
)
|
||||
@@ -48,6 +51,7 @@ var (
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
showQR bool
|
||||
profileName string
|
||||
configPath string
|
||||
|
||||
@@ -80,6 +84,7 @@ func init() {
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||
|
||||
|
||||
@@ -2477,6 +2477,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
||||
}
|
||||
}
|
||||
|
||||
relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())
|
||||
|
||||
offerAnswer := peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
@@ -2487,7 +2489,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassAddr: rosenpassAddr,
|
||||
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
|
||||
RelaySrvIP: relayIP,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
return &offerAnswer, nil
|
||||
}
|
||||
|
||||
// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
|
||||
// netip.Addr. Returns the zero value for empty input and logs a warning
|
||||
// for malformed payloads.
|
||||
func decodeRelayIP(b []byte) netip.Addr {
|
||||
if len(b) == 0 {
|
||||
return netip.Addr{}
|
||||
}
|
||||
ip, ok := netip.AddrFromSlice(b)
|
||||
if !ok {
|
||||
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
|
||||
return netip.Addr{}
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -40,6 +41,10 @@ type OfferAnswer struct {
|
||||
|
||||
// relay server address
|
||||
RelaySrvAddress string
|
||||
// RelaySrvIP is the IP the remote peer is connected to on its
|
||||
// relay server. Used as a dial target if DNS for RelaySrvAddress
|
||||
// fails. Zero value if the peer did not advertise an IP.
|
||||
RelaySrvIP netip.Addr
|
||||
// SessionID is the unique identifier of the session, used to discard old messages
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
@@ -217,8 +222,9 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
answer.SessionID = &sid
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
answer.RelaySrvAddress = addr
|
||||
answer.RelaySrvIP = ip
|
||||
}
|
||||
|
||||
return answer
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
type mocListener struct {
|
||||
lastState int
|
||||
wg sync.WaitGroup
|
||||
peersWg sync.WaitGroup
|
||||
peers int
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
|
||||
}
|
||||
func (l *mocListener) OnPeersListChanged(size int) {
|
||||
l.peers = size
|
||||
l.peersWg.Done()
|
||||
}
|
||||
|
||||
func (l *mocListener) setWaiter() {
|
||||
@@ -43,6 +45,14 @@ func (l *mocListener) wait() {
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
func (l *mocListener) setPeersWaiter() {
|
||||
l.peersWg.Add(1)
|
||||
}
|
||||
|
||||
func (l *mocListener) waitPeers() {
|
||||
l.peersWg.Wait()
|
||||
}
|
||||
|
||||
func Test_notifier_serverState(t *testing.T) {
|
||||
|
||||
type scenario struct {
|
||||
@@ -72,11 +82,13 @@ func Test_notifier_serverState(t *testing.T) {
|
||||
func Test_notifier_SetListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
listener.setPeersWaiter()
|
||||
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
listener.wait()
|
||||
listener.waitPeers()
|
||||
if listener.lastState != n.lastNotification {
|
||||
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
||||
}
|
||||
@@ -85,9 +97,14 @@ func Test_notifier_SetListener(t *testing.T) {
|
||||
func Test_notifier_RemoveListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
listener.setPeersWaiter()
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
// setListener replays cached state on a goroutine; wait for both the state
|
||||
// and peers callbacks to finish so we don't race on listener.peers.
|
||||
listener.wait()
|
||||
listener.waitPeers()
|
||||
n.removeListener()
|
||||
n.peerListChanged(1)
|
||||
|
||||
|
||||
@@ -54,19 +54,19 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
offerAnswer.WgListenPort,
|
||||
remoteKey,
|
||||
&signal.Credential{
|
||||
msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{
|
||||
Type: bodyType,
|
||||
WgListenPort: offerAnswer.WgListenPort,
|
||||
Credential: &signal.Credential{
|
||||
UFrag: offerAnswer.IceCredentials.UFrag,
|
||||
Pwd: offerAnswer.IceCredentials.Pwd,
|
||||
},
|
||||
bodyType,
|
||||
offerAnswer.RosenpassPubKey,
|
||||
offerAnswer.RosenpassAddr,
|
||||
offerAnswer.RelaySrvAddress,
|
||||
sessionIDBytes)
|
||||
RosenpassPubKey: offerAnswer.RosenpassPubKey,
|
||||
RosenpassAddr: offerAnswer.RosenpassAddr,
|
||||
RelaySrvAddress: offerAnswer.RelaySrvAddress,
|
||||
RelaySrvIP: offerAnswer.RelaySrvIP,
|
||||
SessionID: sessionIDBytes,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
// UpdatePeerState updates peer status
|
||||
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -343,23 +343,29 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
// when we close the connection we will not notify the router manager
|
||||
if receivedState.ConnStatus == StatusIdle {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
notifyRouter := receivedState.ConnStatus == StatusIdle
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -371,17 +377,20 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
||||
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
numPeers := d.numOfPeers()
|
||||
d.mux.Unlock()
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -393,8 +402,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.routeIDLookup.RemoveRemoteRouteID(pref)
|
||||
}
|
||||
|
||||
numPeers := d.numOfPeers()
|
||||
d.mux.Unlock()
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -410,10 +422,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
|
||||
|
||||
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -431,22 +443,28 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -461,22 +479,28 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -490,22 +514,28 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -522,12 +552,18 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -594,17 +630,33 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if !d.peerListChangedForNotification {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.peerListChangedForNotification = false
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
// snapshot per-peer router state to deliver after the lock is released
|
||||
type routerDispatch struct {
|
||||
peerID string
|
||||
snapshot map[string]RouterState
|
||||
}
|
||||
dispatches := make([]routerDispatch, 0, len(d.peers))
|
||||
for key := range d.peers {
|
||||
d.notifyPeerStateChangeListeners(key)
|
||||
snapshot := d.snapshotRouterPeersLocked(key, true)
|
||||
if snapshot != nil {
|
||||
dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
|
||||
}
|
||||
}
|
||||
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
for _, rd := range dispatches {
|
||||
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -655,10 +707,12 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
// UpdateLocalPeerState updates local peer status
|
||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = localPeerState
|
||||
d.notifyAddressChanged()
|
||||
fqdn := d.localPeer.FQDN
|
||||
ip := d.localPeer.IP
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
}
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
@@ -721,30 +775,36 @@ func (d *Status) CleanLocalPeerStateRoutes() {
|
||||
// CleanLocalPeerState cleans local peer status
|
||||
func (d *Status) CleanLocalPeerState() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = LocalPeerState{}
|
||||
d.notifyAddressChanged()
|
||||
fqdn := d.localPeer.FQDN
|
||||
ip := d.localPeer.IP
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
}
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = false
|
||||
d.managementError = err
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = true
|
||||
d.managementError = nil
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
@@ -778,21 +838,25 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = false
|
||||
d.signalError = err
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = true
|
||||
d.signalError = nil
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
@@ -919,7 +983,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
@@ -1012,18 +1076,17 @@ func (d *Status) RemoveConnectionListener() {
|
||||
d.notifier.removeListener()
|
||||
}
|
||||
|
||||
func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
subs, ok := d.changeNotify[peerID]
|
||||
if !ok {
|
||||
return
|
||||
// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers.
|
||||
// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID
|
||||
// or when notify is false. The snapshot is consumed later by dispatchRouterPeers
|
||||
// outside the lock so the channel send cannot stall any d.mux holder.
|
||||
func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState {
|
||||
if !notify {
|
||||
return nil
|
||||
}
|
||||
if _, ok := d.changeNotify[peerID]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// collect the relevant data for router peers
|
||||
routerPeers := make(map[string]RouterState, len(d.changeNotify))
|
||||
for pid := range d.changeNotify {
|
||||
s, ok := d.peers[pid]
|
||||
@@ -1031,13 +1094,35 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
log.Warnf("router peer not found in peers list: %s", pid)
|
||||
continue
|
||||
}
|
||||
|
||||
routerPeers[pid] = RouterState{
|
||||
Status: s.ConnStatus,
|
||||
Relayed: s.Relayed,
|
||||
Latency: s.Latency,
|
||||
}
|
||||
}
|
||||
return routerPeers
|
||||
}
|
||||
|
||||
// dispatchRouterPeers delivers a previously snapshotted router-state map to
|
||||
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
|
||||
// fresh, short read of d.changeNotify under the lock to grab subscriber
|
||||
// channels, then sends outside the lock so a slow consumer cannot block other
|
||||
// d.mux holders. The send itself stays blocking (only short-circuited by the
|
||||
// subscriber's context) so peer state transitions are not silently dropped.
|
||||
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
|
||||
if routerPeers == nil {
|
||||
return
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
subsMap, ok := d.changeNotify[peerID]
|
||||
subs := make([]*StatusChangeSubscription, 0, len(subsMap))
|
||||
if ok {
|
||||
for _, sub := range subsMap {
|
||||
subs = append(subs, sub)
|
||||
}
|
||||
}
|
||||
d.mux.Unlock()
|
||||
|
||||
for _, sub := range subs {
|
||||
select {
|
||||
@@ -1047,14 +1132,6 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(d.numOfPeers())
|
||||
}
|
||||
|
||||
func (d *Status) notifyAddressChanged() {
|
||||
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||
}
|
||||
|
||||
func (d *Status) numOfPeers() int {
|
||||
return len(d.peers) + len(d.offlinePeers)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.relaySupportedOnRemotePeer.Store(true)
|
||||
|
||||
// the relayManager will return with error in case if the connection has lost with relay server
|
||||
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
|
||||
currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to handle new offer: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||
var serverIP netip.Addr
|
||||
if srv == remoteOfferAnswer.RelaySrvAddress {
|
||||
serverIP = remoteOfferAnswer.RelaySrvIP
|
||||
}
|
||||
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
|
||||
if err != nil {
|
||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||
@@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||
func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) {
|
||||
return w.relayManager.RelayInstanceAddress()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -177,7 +178,12 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_, gateway, localIP, err = router.Route(net.IPv4zero)
|
||||
dst := net.IPv4zero
|
||||
if runtime.GOOS == "linux" {
|
||||
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
|
||||
dst = net.IPv4(0, 0, 0, 1)
|
||||
}
|
||||
_, gateway, localIP, err = router.Route(dst)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -196,7 +202,12 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_, gateway, localIP, err = router.Route(net.IPv6zero)
|
||||
dst := net.IPv6zero
|
||||
if runtime.GOOS == "linux" {
|
||||
// ::2
|
||||
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||
}
|
||||
_, gateway, localIP, err = router.Route(dst)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -342,6 +342,22 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
|
||||
// go-netroute v0.4.0 rejects unspecified destinations on Linux with a hard
|
||||
// client-side check. Substitute the lowest non-loopback address so the
|
||||
// lookup falls through to the default route (::1 / 127.0.0.1 would match
|
||||
// loopback, ::/0.0.0.0 are unspec). BSD/Windows pass the query straight to
|
||||
// the kernel and need no substitution.
|
||||
if runtime.GOOS == "linux" && ip.IsUnspecified() {
|
||||
if ip.Is6() {
|
||||
// ::2
|
||||
ip = netip.AddrFrom16([16]byte{15: 2})
|
||||
} else {
|
||||
// 0.0.0.1
|
||||
ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||
|
||||
@@ -354,9 +354,13 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
require.NoError(t, err, "Should be able to get IPv4 default route")
|
||||
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
|
||||
|
||||
if testCase.prefix.Addr().Is6() && !testCase.expectError {
|
||||
ensureIPv6DefaultRoute(t)
|
||||
}
|
||||
|
||||
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if testCase.prefix.Addr().Is6() &&
|
||||
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
|
||||
initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun") {
|
||||
t.Skip("Skipping test as no ipv6 default route is available")
|
||||
}
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
|
||||
30
client/internal/routemanager/systemops/v6route_bsd_test.go
Normal file
30
client/internal/routemanager/systemops/v6route_bsd_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
out, err := exec.Command("route", "-6", "add", "default", "-iface", "lo0").CombinedOutput()
|
||||
if err != nil {
|
||||
// Existing default; nothing to install or clean up.
|
||||
if bytes.Contains(out, []byte("route already in table")) {
|
||||
return
|
||||
}
|
||||
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if out, err := exec.Command("route", "-6", "delete", "default").CombinedOutput(); err != nil {
|
||||
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
41
client/internal/routemanager/systemops/v6route_linux_test.go
Normal file
41
client/internal/routemanager/systemops/v6route_linux_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// ensureIPv6DefaultRoute installs a low-preference IPv6 default route via the
|
||||
// loopback interface so route lookups for global IPv6 prefixes resolve in
|
||||
// environments without v6 connectivity. Any pre-existing default route wins
|
||||
// because of its lower metric.
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
lo, err := netlink.LinkByName("lo")
|
||||
require.NoError(t, err, "find loopback interface")
|
||||
|
||||
route := &netlink.Route{
|
||||
Dst: &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)},
|
||||
LinkIndex: lo.Attrs().Index,
|
||||
Priority: 1 << 20,
|
||||
}
|
||||
if err := netlink.RouteAdd(route); err != nil {
|
||||
if errors.Is(err, syscall.EEXIST) {
|
||||
return
|
||||
}
|
||||
t.Skipf("install IPv6 fallback default route: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
|
||||
t.Logf("delete IPv6 fallback default route: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
//go:build windows
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
||||
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
script := `New-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -RouteMetric 9999 -PolicyStore ActiveStore -ErrorAction Stop`
|
||||
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
if err != nil {
|
||||
// Existing default; nothing to install or clean up.
|
||||
if bytes.Contains(out, []byte("already exists")) {
|
||||
return
|
||||
}
|
||||
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
script := `Remove-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -Confirm:$false -ErrorAction Stop`
|
||||
if out, err := exec.Command("powershell", "-Command", script).CombinedOutput(); err != nil {
|
||||
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||
|
||||
// Start begins monitoring the WireGuard interface.
|
||||
// It relies on the provided context cancellation to stop.
|
||||
//
|
||||
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
|
||||
// to avoid the allocation churn of repeatedly dumping the kernel link
|
||||
// table; on other platforms it falls back to a low-frequency poll.
|
||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||
defer close(m.done)
|
||||
|
||||
@@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
||||
|
||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return watchInterface(ctx, ifaceName, expectedIndex)
|
||||
}
|
||||
|
||||
// getInterfaceIndex returns the index of a network interface by name.
|
||||
|
||||
134
client/internal/wg_iface_monitor_linux.go
Normal file
134
client/internal/wg_iface_monitor_linux.go
Normal file
@@ -0,0 +1,134 @@
|
||||
//go:build linux
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
|
||||
// deletion or recreation of the WireGuard interface.
|
||||
//
|
||||
// The previous implementation polled net.InterfaceByName every 2 s, which
|
||||
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
|
||||
// entire kernel link table on every call. On hosts with many veth
|
||||
// interfaces (containers, bridges) the resulting allocation churn was on
|
||||
// the order of ~1 GB/day from this single ticker, which on small ARM
|
||||
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
|
||||
//
|
||||
// The event-driven version below allocates only when the kernel actually
|
||||
// publishes a link event for the tracked interface — typically zero
|
||||
// allocations between events.
|
||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
// Buffer the channel to absorb event bursts (e.g. when many veth
|
||||
// pairs are created/destroyed at once by container runtimes).
|
||||
linkChan := make(chan netlink.LinkUpdate, 32)
|
||||
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||
// Return shouldRestart=true so the engine recovers monitoring
|
||||
// via triggerClientRestart instead of silently losing it for
|
||||
// the rest of the process lifetime.
|
||||
return true, fmt.Errorf("subscribe to link updates: %w", err)
|
||||
}
|
||||
|
||||
// Race window: the interface could have been deleted (or recreated)
|
||||
// between the initial getInterfaceIndex() in Start and LinkSubscribe
|
||||
// completing its handshake with the kernel. Re-check explicitly so we
|
||||
// do not block forever waiting for an event that already fired.
|
||||
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
|
||||
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
} else if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||
|
||||
case update, ok := <-linkChan:
|
||||
if !ok {
|
||||
// The vishvananda/netlink subscription goroutine closes
|
||||
// the channel on receive errors. Signal the engine to
|
||||
// restart so monitoring is re-established instead of
|
||||
// silently ending.
|
||||
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
|
||||
return true, fmt.Errorf("link subscription channel closed unexpectedly")
|
||||
}
|
||||
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inspectLinkEvent classifies a single netlink link update against the
|
||||
// tracked WireGuard interface. It returns (true, err) when the engine
|
||||
// should restart monitoring; (false, nil) means the event is unrelated
|
||||
// and the caller should keep waiting.
|
||||
//
|
||||
// The error component, when non-nil, describes the kernel-side reason
|
||||
// (deletion or rename); the recreation case returns (true, nil) since
|
||||
// no error condition is reported.
|
||||
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
|
||||
eventIndex := int(update.Index)
|
||||
eventName := ""
|
||||
if attrs := update.Attrs(); attrs != nil {
|
||||
eventName = attrs.Name
|
||||
}
|
||||
|
||||
switch update.Header.Type {
|
||||
case syscall.RTM_DELLINK:
|
||||
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
|
||||
case syscall.RTM_NEWLINK:
|
||||
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
|
||||
// tracked interface index.
|
||||
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
|
||||
if eventIndex != expectedIndex {
|
||||
return false, nil
|
||||
}
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted", ifaceName)
|
||||
}
|
||||
|
||||
// inspectNewLink reports a restart when an RTM_NEWLINK either:
|
||||
//
|
||||
// 1. Introduces a link with our name at a different index (recreation
|
||||
// after a delete), or
|
||||
//
|
||||
// 2. Reports a link still at our index but with a different name
|
||||
// (in-place rename). The previous polling implementation caught
|
||||
// this implicitly because net.InterfaceByName(ifaceName) would
|
||||
// start failing; the event-driven version has to test it.
|
||||
//
|
||||
// Same name + same index is just a flag/state change on the existing
|
||||
// interface and is ignored.
|
||||
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
|
||||
if eventName == ifaceName && eventIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, eventIndex)
|
||||
return true, nil
|
||||
}
|
||||
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
|
||||
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
|
||||
ifaceName, eventName, expectedIndex)
|
||||
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
56
client/internal/wg_iface_monitor_other.go
Normal file
56
client/internal/wg_iface_monitor_other.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build !linux
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// watchInterface polls net.InterfaceByName at a fixed interval to detect
|
||||
// deletion or recreation of the WireGuard interface.
|
||||
//
|
||||
// This is the fallback used on non-Linux desktop and server platforms
|
||||
// (darwin, windows, freebsd). It is also compiled on android and ios so
|
||||
// the package builds on every supported GOOS, but it is never reached
|
||||
// at runtime there because Start() in wg_iface_monitor.go exits early
|
||||
// on mobile platforms.
|
||||
//
|
||||
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
|
||||
// RTNLGRP_LINK netlink subscription instead, because on Linux
|
||||
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
|
||||
// dumps the entire kernel link table on every call and produces
|
||||
// significant allocation churn (netbirdio/netbird#3678).
|
||||
//
|
||||
// Windows is also reported in #3678 as affected by RSS climb. A future
|
||||
// follow-up could implement an event-driven watcher there using
|
||||
// NotifyIpInterfaceChange from iphlpapi.
|
||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -224,15 +224,20 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
sshConfigPathTmp := sshConfigPath + ".tmp"
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
|
||||
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,11 +13,9 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
@@ -301,12 +299,11 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// isContextDone reports whether the local context has been canceled or has
|
||||
// exceeded its deadline. It deliberately does not inspect gRPC status codes:
|
||||
// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not
|
||||
// short-circuit our retry loop, since retrying is the correct response when
|
||||
// the local context is still alive.
|
||||
func isContextDone(err error) bool {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
|
||||
}
|
||||
return false
|
||||
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
20
go.mod
20
go.mod
@@ -17,8 +17,8 @@ require (
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/spf13/pflag v1.0.9
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/sys v0.42.0
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/sys v0.43.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
@@ -68,9 +68,10 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-nat v0.2.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/libp2p/go-netroute v0.4.0
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||
@@ -117,11 +118,11 @@ require (
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
|
||||
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
|
||||
golang.org/x/mod v0.33.0
|
||||
golang.org/x/net v0.52.0
|
||||
golang.org/x/mod v0.34.0
|
||||
golang.org/x/net v0.53.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/term v0.41.0
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/time v0.15.0
|
||||
google.golang.org/api v0.276.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -302,13 +303,14 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/image v0.33.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
golang.org/x/tools v0.43.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
rsc.io/qr v0.2.0 // indirect
|
||||
)
|
||||
|
||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
||||
@@ -321,8 +323,6 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111
|
||||
|
||||
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
|
||||
|
||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||
|
||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
||||
|
||||
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0
|
||||
|
||||
36
go.sum
36
go.sum
@@ -395,6 +395,8 @@ github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
|
||||
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
|
||||
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
|
||||
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
|
||||
github.com/libp2p/go-netroute v0.4.0 h1:sZZx9hyANYUx9PZyqcgE/E1GUG3iEtTZHUEvdtXT7/Q=
|
||||
github.com/libp2p/go-netroute v0.4.0/go.mod h1:Nkd5ShYgSMS5MUKy/MU2T57xFoOKvvLR92Lic48LEyA=
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
@@ -415,6 +417,8 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||
@@ -451,8 +455,6 @@ github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U
|
||||
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
||||
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
|
||||
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
|
||||
@@ -709,8 +711,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
|
||||
@@ -727,8 +729,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
@@ -747,8 +749,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
@@ -799,8 +801,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@@ -813,8 +815,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -826,8 +828,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -841,8 +843,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -915,3 +917,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||
|
||||
@@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing connector in Dex storage.
|
||||
// It merges incoming updates with existing values to prevent data loss on partial updates.
|
||||
// It overlays user-mutable config fields (issuer, clientID, clientSecret,
|
||||
// redirectURI) onto the stored connector config, and updates the connector name
|
||||
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
|
||||
// partial updates preserve create-time defaults such as scopes, claimMapping,
|
||||
// and userIDKey.
|
||||
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
|
||||
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||
oldCfg, err := p.parseStorageConnector(old)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
|
||||
if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
|
||||
return storage.Connector{}, errors.New("connector type change not allowed")
|
||||
}
|
||||
|
||||
mergeConnectorConfig(cfg, oldCfg)
|
||||
|
||||
storageConn, err := p.buildStorageConnector(cfg)
|
||||
configData, err := overlayConnectorConfig(old.Config, cfg)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
|
||||
return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
|
||||
}
|
||||
return storageConn, nil
|
||||
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = old.Name
|
||||
}
|
||||
|
||||
return storage.Connector{
|
||||
ID: cfg.ID,
|
||||
Type: old.Type,
|
||||
Name: name,
|
||||
Config: configData,
|
||||
}, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update connector: %w", err)
|
||||
}
|
||||
@@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeConnectorConfig preserves existing values for empty fields in the update.
|
||||
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
|
||||
if cfg.ClientSecret == "" {
|
||||
cfg.ClientSecret = oldCfg.ClientSecret
|
||||
// overlayConnectorConfig writes only the user-mutable fields onto the existing
|
||||
// stored config, preserving every other field (scopes, claimMapping, userIDKey,
|
||||
// insecure flags, etc.). Empty fields on cfg leave the existing value alone.
|
||||
func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
|
||||
var m map[string]any
|
||||
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.RedirectURI == "" {
|
||||
cfg.RedirectURI = oldCfg.RedirectURI
|
||||
if cfg.Issuer != "" {
|
||||
m["issuer"] = cfg.Issuer
|
||||
}
|
||||
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
|
||||
cfg.Issuer = oldCfg.Issuer
|
||||
if cfg.ClientID != "" {
|
||||
m["clientID"] = cfg.ClientID
|
||||
}
|
||||
if cfg.ClientID == "" {
|
||||
cfg.ClientID = oldCfg.ClientID
|
||||
if cfg.ClientSecret != "" {
|
||||
m["clientSecret"] = cfg.ClientSecret
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
cfg.Name = oldCfg.Name
|
||||
if cfg.RedirectURI != "" {
|
||||
m["redirectURI"] = cfg.RedirectURI
|
||||
}
|
||||
return encodeConnectorConfig(m)
|
||||
}
|
||||
|
||||
// DeleteConnector removes a connector from Dex storage.
|
||||
@@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["getUserInfo"] = true
|
||||
case "entra":
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
|
||||
// Entra issues sub as a per-app pairwise identifier that does not match
|
||||
// the stable Object ID.
|
||||
oidcConfig["userIDKey"] = "oid"
|
||||
case "okta":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
|
||||
205
idp/dex/connector_test.go
Normal file
205
idp/dex/connector_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestProvider(t *testing.T) (*Provider, func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &Provider{storage: s, logger: logger}, func() {
|
||||
_ = s.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
|
||||
cfg := &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
}
|
||||
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &m))
|
||||
|
||||
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
|
||||
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
|
||||
}
|
||||
|
||||
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
|
||||
// ensures the Entra userIDKey override does not leak into other OIDC providers,
|
||||
// which already use a stable sub claim.
|
||||
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
|
||||
t.Run(typ, func(t *testing.T) {
|
||||
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &m))
|
||||
_, ok := m["userIDKey"]
|
||||
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
created, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "old-secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "entra-test", created.ID)
|
||||
|
||||
// Rotate only the client secret.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "entra",
|
||||
ClientSecret: "new-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
|
||||
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
|
||||
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
|
||||
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
|
||||
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
|
||||
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
|
||||
}
|
||||
|
||||
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
// Seed a connector directly into storage without userIDKey
|
||||
preFixConfig, err := json.Marshal(map[string]any{
|
||||
"issuer": "https://login.microsoftonline.com/tid/v2.0",
|
||||
"clientID": "client-id",
|
||||
"clientSecret": "old-secret",
|
||||
"redirectURI": "https://example.com/oauth2/callback",
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"claimMapping": map[string]string{"email": "preferred_username"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
|
||||
ID: "entra-prefix",
|
||||
Type: "oidc",
|
||||
Name: "Entra",
|
||||
Config: preFixConfig,
|
||||
}))
|
||||
|
||||
// Rotate client secret via UpdateConnector.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-prefix",
|
||||
Type: "entra",
|
||||
ClientSecret: "new-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
|
||||
assert.Equal(t, "new-secret", m["clientSecret"])
|
||||
_, has := m["userIDKey"]
|
||||
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
|
||||
}
|
||||
|
||||
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to switch the connector to okta.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "okta",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "connector type change not allowed")
|
||||
|
||||
// stored connector type/config unchanged after the rejected update.
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "oidc", conn.Type)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
assert.Equal(t, "oid", m["userIDKey"])
|
||||
}
|
||||
|
||||
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/old/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/new/v2.0",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
|
||||
}
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||
Heartbeat(ctx context.Context, p *Proxy) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
@@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
}
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
SessionID: sessionID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
@@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"sessionID": sessionID,
|
||||
"clusterAddress": clusterAddress,
|
||||
"ipAddress": ipAddress,
|
||||
}).Info("proxy connected")
|
||||
|
||||
return nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
// Disconnect marks a proxy as disconnected in the database.
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"proxyID": proxyID,
|
||||
"sessionID": sessionID,
|
||||
}).Info("proxy disconnected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
// Heartbeat updates the proxy's last seen timestamp.
|
||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
|
||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
|
||||
m.metrics.IncrementProxyHeartbeatCount()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Disconnect indicates an expected call of Disconnect.
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses mocks base method.
|
||||
@@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Heartbeat indicates an expected call of Heartbeat.
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
|
||||
@@ -18,12 +18,13 @@ type Capabilities struct {
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
SessionID string `gorm:"type:varchar(36)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Capabilities Capabilities `gorm:"embedded"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -11,11 +11,14 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -81,14 +84,44 @@ type ProxyServiceServer struct {
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
|
||||
// tokenTTL is the lifetime of one-time tokens generated for proxy
|
||||
// authentication. Defaults to defaultProxyTokenTTL when zero.
|
||||
tokenTTL time.Duration
|
||||
|
||||
// snapshotBatchSize is the number of mappings per gRPC message during
|
||||
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
|
||||
snapshotBatchSize int
|
||||
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
const defaultProxyTokenTTL = 5 * time.Minute
|
||||
|
||||
const defaultSnapshotBatchSize = 500
|
||||
|
||||
func snapshotBatchSizeFromEnv() int {
|
||||
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultSnapshotBatchSize
|
||||
}
|
||||
|
||||
// proxyTokenTTL returns the configured token TTL or the default when unset.
|
||||
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
|
||||
if s.tokenTTL > 0 {
|
||||
return s.tokenTTL
|
||||
}
|
||||
return defaultProxyTokenTTL
|
||||
}
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
sessionID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
@@ -108,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
cancel: cancel,
|
||||
}
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
@@ -166,9 +200,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
@@ -177,79 +224,93 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
cancel()
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
}()
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
// Start heartbeat goroutine
|
||||
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -267,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return err
|
||||
}
|
||||
|
||||
// Send mappings in batches to reduce per-message gRPC overhead while
|
||||
// staying well within the default 4 MB message size limit.
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, m := range mappings {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{m},
|
||||
InitialSyncComplete: i == len(mappings)-1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -300,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"service": service.Name,
|
||||
"account": service.AccountID,
|
||||
}).WithError(err).Error("failed to generate auth token for snapshot")
|
||||
continue
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
@@ -386,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
conn := value.(*proxyConnection)
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- resp:
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -472,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -504,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
// Returns nil if token generation fails (the proxy should be skipped).
|
||||
// Returns nil if token generation fails; the caller must disconnect the
|
||||
// proxy so it can resync via a fresh snapshot on reconnect.
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, mapping := range update.Mapping {
|
||||
@@ -513,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
|
||||
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// recordingStream captures all messages sent via Send so tests can inspect
|
||||
// batching behaviour without a real gRPC transport.
|
||||
type recordingStream struct {
|
||||
grpc.ServerStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
}
|
||||
|
||||
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
s.messages = append(s.messages, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *recordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *recordingStream) SendMsg(any) error { return nil }
|
||||
func (s *recordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// makeServices creates n enabled services assigned to the given cluster.
|
||||
func makeServices(n int, cluster string) []*rpservice.Service {
|
||||
services := make([]*rpservice.Service, n)
|
||||
for i := range n {
|
||||
services[i] = &rpservice.Service{
|
||||
ID: fmt.Sprintf("svc-%d", i),
|
||||
AccountID: "acct-1",
|
||||
Name: fmt.Sprintf("svc-%d", i),
|
||||
Domain: fmt.Sprintf("svc-%d.example.com", i),
|
||||
ProxyCluster: cluster,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
|
||||
},
|
||||
}
|
||||
}
|
||||
return services
|
||||
}
|
||||
|
||||
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
|
||||
t.Helper()
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
|
||||
snapshotBatchSize: batchSize,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
return s
|
||||
}
|
||||
|
||||
func TestSendSnapshot_BatchesMappings(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshot(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect ceil(7/3) = 3 messages
|
||||
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[2].Mapping, 1)
|
||||
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
|
||||
|
||||
// Verify all service IDs are present exactly once
|
||||
seen := make(map[string]bool)
|
||||
for _, msg := range stream.messages {
|
||||
for _, m := range msg.Mapping {
|
||||
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
|
||||
seen[m.Id] = true
|
||||
}
|
||||
}
|
||||
assert.Len(t, seen, totalServices)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 6 // exactly 2 batches
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 2)
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.True(t, stream.messages[1].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_SingleBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
|
||||
assert.Len(t, stream.messages[0].Mapping, totalServices)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
|
||||
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
|
||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
capabilities: caps,
|
||||
sendChan: ch,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
|
||||
@@ -818,6 +818,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
|
||||
}
|
||||
if newPeer.Status != nil && newPeer.Status.RequiresApproval {
|
||||
opEvent.Meta["pending_approval"] = true
|
||||
}
|
||||
|
||||
if !temporary {
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
@@ -5437,13 +5437,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
// DisconnectProxy marks a proxy as disconnected only if the session ID matches.
|
||||
// This prevents a slow-to-close old session from overwriting a newer reconnection.
|
||||
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
now := time.Now()
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND session_id = ?", proxyID, sessionID).
|
||||
Updates(map[string]any{
|
||||
"status": "disconnected",
|
||||
"disconnected_at": now,
|
||||
"last_seen": now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error)
|
||||
return status.Errorf(status.Internal, "failed to disconnect proxy")
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
now := time.Now()
|
||||
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
|
||||
Update("last_seen", now)
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -5452,17 +5474,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
if err := s.db.Save(p).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
|
||||
p.LastSeen = now
|
||||
p.ConnectedAt = &now
|
||||
p.Status = "connected"
|
||||
if err := s.db.Create(p).Error; err != nil {
|
||||
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -284,7 +284,8 @@ type Store interface {
|
||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||
|
||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2799,6 +2800,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2995,17 +3010,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
|
||||
@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify one character in the secret part
|
||||
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
|
||||
replacement := "X"
|
||||
if plainToken[5] == 'X' {
|
||||
replacement = "Y"
|
||||
}
|
||||
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
|
||||
err = ValidateInviteToken(modifiedToken)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||
type testProxyManager struct{}
|
||||
|
||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
|
||||
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings from the snapshot - server sends each mapping individually
|
||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
for _, m := range msg.GetMapping() {
|
||||
mappingsByID[m.GetId()] = m
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive 2 mappings total
|
||||
@@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
mappings := make([]*proto.ProxyMapping, 0)
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive the 2 mappings matching the cluster
|
||||
@@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-reconnect"
|
||||
|
||||
// Helper to receive all mappings from a stream
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||
var mappings []*proto.ProxyMapping
|
||||
for i := 0; i < count; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
@@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveMappings(stream1, 2)
|
||||
firstMappings := receiveMappings(stream1)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveMappings(stream2, 2)
|
||||
secondMappings := receiveMappings(stream2)
|
||||
|
||||
// Should receive the same mappings
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings),
|
||||
@@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to receive and apply all mappings
|
||||
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
applyMappings(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
count := 0
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
count += len(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
@@ -656,3 +666,78 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
|
||||
// when a proxy reconnects before the old stream's cleanup runs, the new
|
||||
// connection is NOT removed by the stale defer.
|
||||
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-race"
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
msg, err := stream1.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
"proxy should be registered after first connection")
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
msg, err := stream2.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cancel1()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
"proxy should still be registered after old connection cleanup — old defer must not remove new connection")
|
||||
|
||||
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||
Id: "rp-1",
|
||||
AccountId: "test-account-1",
|
||||
Domain: "app1.test.proxy.io",
|
||||
}},
|
||||
})
|
||||
|
||||
msg, err := stream2.Recv()
|
||||
require.NoError(t, err, "new stream should still receive updates")
|
||||
require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
|
||||
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
|
||||
}
|
||||
|
||||
@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
operation := func() error {
|
||||
s.Logger.Debug("connecting to management mapping stream")
|
||||
|
||||
initialSyncDone = false
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var snapshotIDs map[types.ServiceID]struct{}
|
||||
if !*initialSyncDone {
|
||||
snapshotIDs = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
|
||||
for {
|
||||
// Check for context completion to gracefully shutdown.
|
||||
select {
|
||||
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
if !*initialSyncDone {
|
||||
for _, m := range msg.GetMapping() {
|
||||
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
s.reconcileSnapshot(ctx, snapshotIDs)
|
||||
snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||
s.portMu.RLock()
|
||||
var stale []*proto.ProxyMapping
|
||||
for svcID, mapping := range s.lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, mapping)
|
||||
}
|
||||
}
|
||||
s.portMu.RUnlock()
|
||||
|
||||
for _, mapping := range stale {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Info("Removing stale mapping absent from snapshot")
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
|
||||
227
proxy/snapshot_reconcile_test.go
Normal file
227
proxy/snapshot_reconcile_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
|
||||
// so we can verify it without triggering removeMapping (which requires full
|
||||
// server wiring). This keeps the test focused on the detection algorithm.
|
||||
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
|
||||
var stale []types.ServiceID
|
||||
for svcID := range lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, svcID)
|
||||
}
|
||||
}
|
||||
return stale
|
||||
}
|
||||
|
||||
// TestStaleDetection_PartialOverlap verifies that only services absent from
|
||||
// the snapshot are flagged as stale.
|
||||
func TestStaleDetection_PartialOverlap(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
"svc-stale-a": {Id: "svc-stale-a"},
|
||||
"svc-stale-b": {Id: "svc-stale-b"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
"svc-3": {}, // new service, not in local
|
||||
}
|
||||
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Len(t, stale, 2)
|
||||
staleSet := make(map[types.ServiceID]struct{})
|
||||
for _, id := range stale {
|
||||
staleSet[id] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
|
||||
}
|
||||
|
||||
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
|
||||
func TestStaleDetection_AllStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
|
||||
assert.Len(t, stale, 2)
|
||||
}
|
||||
|
||||
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
|
||||
func TestStaleDetection_NoneStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
|
||||
func TestStaleDetection_EmptyLocal(t *testing.T) {
|
||||
stale := collectStaleIDs(
|
||||
map[types.ServiceID]*proto.ProxyMapping{},
|
||||
map[types.ServiceID]struct{}{"svc-1": {}},
|
||||
)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
|
||||
// local mappings are present in the snapshot (removeMapping is never called).
|
||||
func TestReconcileSnapshot_NoStale(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
|
||||
|
||||
snapshotIDs := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
// This should not panic — no stale entries means removeMapping is never called.
|
||||
s.reconcileSnapshot(context.Background(), snapshotIDs)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
|
||||
// no local mappings.
|
||||
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
|
||||
assert.Empty(t, s.lastMappings)
|
||||
}
|
||||
|
||||
// --- handleMappingStream tests for batched snapshot ID accumulation ---
|
||||
|
||||
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
|
||||
// marked done only after the final InitialSyncComplete message, even when
|
||||
// the snapshot arrives in multiple batches.
|
||||
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
|
||||
checker := health.NewChecker(nil, nil)
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // batch 1: no sync-complete
|
||||
{}, // batch 2: no sync-complete
|
||||
{InitialSyncComplete: true}, // batch 3: sync done
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "sync should be marked done after final batch")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
|
||||
// arriving after InitialSyncComplete do not trigger a second reconciliation.
|
||||
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// Simulate state left over from a previous sync.
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // post-sync empty message — must not reconcile
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := true // sync already completed in a previous stream
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2,
|
||||
"post-sync messages must not trigger reconciliation — all entries should survive")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
|
||||
// stream closes before sync completes, no reconciliation occurs.
|
||||
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{} // no messages → immediate EOF
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
|
||||
}
|
||||
|
||||
// mockErrRecvStream returns an error on the second Recv to verify
|
||||
// handleMappingStream returns without completing sync.
|
||||
type mockErrRecvStream struct {
|
||||
mockMappingStream
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &proto.GetMappingUpdateResponse{}, nil
|
||||
}
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, syncDone)
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
|
||||
}
|
||||
@@ -246,27 +246,23 @@ func (c *GrpcClient) handleJobStream(
|
||||
for {
|
||||
jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
log.Debugf("job stream context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
}
|
||||
if s, ok := gstatus.FromError(err); ok {
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
c.notifyDisconnected(err)
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("job stream context has been canceled, this usually indicates shutdown")
|
||||
return err
|
||||
case codes.Unimplemented:
|
||||
log.Warn("Job feature is not supported by the current management server version. " +
|
||||
"Please update the management service to use this feature.")
|
||||
return nil
|
||||
default:
|
||||
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// non-gRPC error
|
||||
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if jobReq == nil || len(jobReq.ID) == 0 {
|
||||
@@ -381,22 +377,15 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
|
||||
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
if s, ok := gstatus.FromError(err); ok {
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
default:
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// non-gRPC error
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
if ctx.Err() != nil {
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
}
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
}
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,8 +2,12 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -146,6 +150,7 @@ func (cc *connContainer) close() {
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
connectionURL string
|
||||
serverIP netip.Addr
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID messages.PeerID
|
||||
|
||||
@@ -170,13 +175,22 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
// is called.
|
||||
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
return NewClientWithServerIP(serverURL, netip.Addr{}, authTokenStore, peerID, mtu)
|
||||
}
|
||||
|
||||
// NewClientWithServerIP creates a new client for the relay server with a known server IP. serverIP, when valid, is
|
||||
// dialed directly first; the FQDN is only attempted if the IP-based dial fails. TLS verification still uses the
|
||||
// FQDN from serverURL via SNI.
|
||||
func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
hashedID := messages.HashID(peerID)
|
||||
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||
|
||||
c := &Client{
|
||||
log: relayLog,
|
||||
connectionURL: serverURL,
|
||||
serverIP: serverIP,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
mtu: mtu,
|
||||
@@ -304,6 +318,23 @@ func (c *Client) ServerInstanceURL() (string, error) {
|
||||
return c.instanceURL.String(), nil
|
||||
}
|
||||
|
||||
// ConnectedIP returns the IP address of the live relay-server connection,
|
||||
// extracted from the underlying socket's RemoteAddr. Zero value if not
|
||||
// connected or if the address is not an IP literal.
|
||||
func (c *Client) ConnectedIP() netip.Addr {
|
||||
c.mu.Lock()
|
||||
conn := c.relayConn
|
||||
c.mu.Unlock()
|
||||
if conn == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
addr := conn.RemoteAddr()
|
||||
if addr == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return extractIPLiteral(addr.String())
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
||||
c.listenerMutex.Lock()
|
||||
@@ -332,10 +363,23 @@ func (c *Client) Close() error {
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
dialers := c.getDialers()
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
conn, err := rd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial via FQDN: %w", err)
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
@@ -351,6 +395,52 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
return instanceURL, nil
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
}
|
||||
|
||||
c.log.Debugf("dialing via server IP %s (SNI=%s)", c.serverIP, serverName)
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
// substituteHost replaces the host portion of a rel/rels URL with ip,
|
||||
// preserving the scheme and port. Returns the rewritten URL and the
|
||||
// original host to use as the TLS ServerName, or empty if the original
|
||||
// host is itself an IP literal (SNI requires a DNS name).
|
||||
func substituteHost(serverURL string, ip netip.Addr) (string, string, error) {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("parse %q: %w", serverURL, err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", fmt.Errorf("invalid relay URL %q", serverURL)
|
||||
}
|
||||
if !ip.IsValid() {
|
||||
return "", "", errors.New("invalid server IP")
|
||||
}
|
||||
origHost := u.Hostname()
|
||||
if _, err := netip.ParseAddr(origHost); err == nil {
|
||||
origHost = ""
|
||||
}
|
||||
ip = ip.Unmap()
|
||||
newHost := ip.String()
|
||||
if ip.Is6() {
|
||||
newHost = "[" + newHost + "]"
|
||||
}
|
||||
if port := u.Port(); port != "" {
|
||||
u.Host = newHost + ":" + port
|
||||
} else {
|
||||
u.Host = newHost
|
||||
}
|
||||
return u.String(), origHost, nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
@@ -716,3 +806,21 @@ func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
|
||||
}
|
||||
c.stateSubscription.OnPeersWentOffline(peersID)
|
||||
}
|
||||
|
||||
// extractIPLiteral returns the IP from address forms produced by the relay
|
||||
// dialers (URL or host:port). Zero value if the host is not an IP.
|
||||
func extractIPLiteral(s string) netip.Addr {
|
||||
if u, err := url.Parse(s); err == nil && u.Host != "" {
|
||||
s = u.Host
|
||||
}
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
host = s
|
||||
}
|
||||
host = strings.Trim(host, "[]")
|
||||
ip, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
280
shared/relay/client/client_serverip_test.go
Normal file
280
shared/relay/client/client_serverip_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/shared/relay/auth/allow"
|
||||
)
|
||||
|
||||
// TestClient_ServerIPRecoversFromUnresolvableFQDN verifies that when the
|
||||
// primary FQDN-based dial fails (unresolvable .invalid host), Connect
|
||||
// recovers via the server IP and SNI still uses the FQDN.
|
||||
func TestClient_ServerIPRecoversFromUnresolvableFQDN(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
listenAddr, port := freeAddr(t)
|
||||
srvCfg := server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port),
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
}
|
||||
srv, err := server.NewServer(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("create server: %s", err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("shutdown server: %s", err)
|
||||
}
|
||||
})
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("server failed to start: %s", err)
|
||||
}
|
||||
|
||||
t.Run("no server IP, primary fails", func(t *testing.T) {
|
||||
c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-noip", iface.DefaultMTU)
|
||||
err := c.Connect(ctx)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
t.Fatalf("expected connect to fail without server IP, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server IP recovers", func(t *testing.T) {
|
||||
c := NewClientWithServerIP(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-with-ip", iface.DefaultMTU)
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
t.Fatalf("connect with server IP: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = c.Close() })
|
||||
|
||||
if !c.Ready() {
|
||||
t.Fatalf("client not ready after connect")
|
||||
}
|
||||
if got := c.ConnectedIP(); got.String() != "127.0.0.1" {
|
||||
t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the
|
||||
// resolved IP after a successful FQDN-based dial. The underlying socket's
|
||||
// RemoteAddr must be exposed through the dialer wrappers; if it returns
|
||||
// the dial-time URL instead, ConnectedIP returns empty and the dial
|
||||
// IP we advertise to peers is empty too.
|
||||
func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
listenAddr, port := freeAddr(t)
|
||||
srvCfg := server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: fmt.Sprintf("rel://localhost:%d", port),
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
}
|
||||
srv, err := server.NewServer(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() { _ = srv.Shutdown(context.Background()) })
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("server failed to start: %s", err)
|
||||
}
|
||||
|
||||
c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-fqdn", iface.DefaultMTU)
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
t.Fatalf("connect: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = c.Close() })
|
||||
|
||||
got := c.ConnectedIP().String()
|
||||
if got != "127.0.0.1" && got != "::1" {
|
||||
t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubstituteHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverURL string
|
||||
ip string
|
||||
wantURL string
|
||||
wantServerName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "rels with port",
|
||||
serverURL: "rels://relay.netbird.io:443",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5:443",
|
||||
wantServerName: "relay.netbird.io",
|
||||
},
|
||||
{
|
||||
name: "rel with port",
|
||||
serverURL: "rel://relay.example.com:80",
|
||||
ip: "192.0.2.1",
|
||||
wantURL: "rel://192.0.2.1:80",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "ipv6 server IP bracketed",
|
||||
serverURL: "rels://relay.example.com:443",
|
||||
ip: "2001:db8::1",
|
||||
wantURL: "rels://[2001:db8::1]:443",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "no port",
|
||||
serverURL: "rels://relay.example.com",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "ipv6 server with port returns empty SNI",
|
||||
serverURL: "rels://[2001:db8::5]:443",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5:443",
|
||||
wantServerName: "",
|
||||
},
|
||||
{
|
||||
name: "ipv4 server with port returns empty SNI",
|
||||
serverURL: "rels://10.0.0.5:443",
|
||||
ip: "10.0.0.6",
|
||||
wantURL: "rels://10.0.0.6:443",
|
||||
wantServerName: "",
|
||||
},
|
||||
{
|
||||
name: "ipv6 server IP no port",
|
||||
serverURL: "rels://relay.example.com",
|
||||
ip: "2001:db8::1",
|
||||
wantURL: "rels://[2001:db8::1]",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "missing scheme",
|
||||
serverURL: "relay.example.com:443",
|
||||
ip: "10.0.0.5",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
serverURL: "",
|
||||
ip: "10.0.0.5",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ip netip.Addr
|
||||
if tt.ip != "" {
|
||||
ip = netip.MustParseAddr(tt.ip)
|
||||
}
|
||||
gotURL, gotName, err := substituteHost(tt.serverURL, ip)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if gotURL != tt.wantURL {
|
||||
t.Errorf("URL = %q, want %q", gotURL, tt.wantURL)
|
||||
}
|
||||
if gotName != tt.wantServerName {
|
||||
t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) {
|
||||
c := NewClient("rel://example.invalid:80", hmacTokenStore, "x", iface.DefaultMTU)
|
||||
if got := c.ConnectedIP(); got.IsValid() {
|
||||
t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
// staticAddr is a net.Addr that returns a fixed string. Used to verify
|
||||
// ConnectedIP parses RemoteAddr correctly.
|
||||
type staticAddr struct{ s string }
|
||||
|
||||
func (a staticAddr) Network() string { return "tcp" }
|
||||
func (a staticAddr) String() string { return a.s }
|
||||
|
||||
type stubConn struct {
|
||||
net.Conn
|
||||
remote net.Addr
|
||||
}
|
||||
|
||||
func (s stubConn) RemoteAddr() net.Addr { return s.remote }
|
||||
|
||||
func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"},
|
||||
{"hostport ipv6 bracketed", "[::1]:50301", "::1"},
|
||||
{"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"},
|
||||
{"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"fqdn url returns empty", "rel://relay.example.com:50301", ""},
|
||||
{"fqdn hostport returns empty", "relay.example.com:50301", ""},
|
||||
{"plain ipv4 no port", "10.0.0.1", "10.0.0.1"},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}}
|
||||
got := c.ConnectedIP()
|
||||
var gotStr string
|
||||
if got.IsValid() {
|
||||
gotStr = got.String()
|
||||
}
|
||||
if gotStr != tt.want {
|
||||
t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The
|
||||
// listener is closed before returning, so the port is briefly free for
|
||||
// the caller to bind. Avoids hardcoded ports that can collide.
|
||||
func freeAddr(t *testing.T) (string, int) {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("get free port: %s", err)
|
||||
}
|
||||
addr := l.Addr().(*net.TCPAddr)
|
||||
_ = l.Close()
|
||||
return addr.String(), addr.Port
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
// Get the base TLS config
|
||||
tlsClientConfig := quictls.ClientQUICTLSConfig()
|
||||
|
||||
// Set ServerName to hostname if not an IP address
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
// It's a hostname, not an IP - modify directly
|
||||
tlsClientConfig.ServerName = host
|
||||
switch {
|
||||
case serverName != "" && net.ParseIP(serverName) == nil:
|
||||
tlsClientConfig.ServerName = serverName
|
||||
default:
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
tlsClientConfig.ServerName = host
|
||||
}
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
|
||||
@@ -14,7 +14,9 @@ const (
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
Dial(ctx context.Context, address string) (net.Conn, error)
|
||||
// Dial connects to address. serverName, when non-empty, overrides the TLS
|
||||
// ServerName used for SNI/cert validation. Empty means derive from address.
|
||||
Dial(ctx context.Context, address, serverName string) (net.Conn, error)
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
@@ -27,6 +29,7 @@ type dialResult struct {
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
}
|
||||
@@ -40,6 +43,16 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
|
||||
}
|
||||
}
|
||||
|
||||
// WithServerName sets a TLS SNI/cert validation override. Used when serverURL
|
||||
// contains an IP literal but the cert is issued for a different hostname.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
r.serverName = serverName
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
@@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(ctx, r.serverURL)
|
||||
conn, err := dfn.Dial(ctx, r.serverURL, r.serverName)
|
||||
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ type MockDialer struct {
|
||||
protocolStr string
|
||||
}
|
||||
|
||||
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
func (m *MockDialer) Dial(ctx context.Context, address, _ string) (net.Conn, error) {
|
||||
return m.dialFunc(ctx, address)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,14 +12,24 @@ import (
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
*websocket.Conn
|
||||
remoteAddr WebsocketAddr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
|
||||
// NewConn builds a relay ws.Conn. underlying is the raw TCP/TLS conn captured
|
||||
// from the http transport's DialContext; when set, RemoteAddr returns its
|
||||
// peer address (an IP literal). When nil (e.g. wasm), RemoteAddr falls back
|
||||
// to the dial-time URL.
|
||||
func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn) net.Conn {
|
||||
var addr net.Addr = WebsocketAddr{serverAddress}
|
||||
if underlying != nil {
|
||||
if ra := underlying.RemoteAddr(); ra != nil {
|
||||
addr = ra
|
||||
}
|
||||
}
|
||||
return &Conn{
|
||||
ctx: context.Background(),
|
||||
Conn: wsConn,
|
||||
remoteAddr: WebsocketAddr{serverAddress},
|
||||
remoteAddr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,14 @@
|
||||
|
||||
package ws
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
import (
|
||||
"net"
|
||||
|
||||
func createDialOptions() *websocket.DialOptions {
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func createDialOptions(serverName string, underlyingOut *net.Conn) *websocket.DialOptions {
|
||||
return &websocket.DialOptions{
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
HTTPClient: httpClientNbDialer(serverName, underlyingOut),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,13 @@
|
||||
|
||||
package ws
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
import (
|
||||
"net"
|
||||
|
||||
func createDialOptions() *websocket.DialOptions {
|
||||
// WASM version doesn't support HTTPClient
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func createDialOptions(_ string, _ *net.Conn) *websocket.DialOptions {
|
||||
// WASM version doesn't support HTTPClient or custom TLS config.
|
||||
return &websocket.DialOptions{}
|
||||
}
|
||||
|
||||
@@ -26,13 +26,14 @@ func (d Dialer) Protocol() string {
|
||||
return "WS"
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := createDialOptions()
|
||||
var underlying net.Conn
|
||||
opts := createDialOptions(serverName, &underlying)
|
||||
|
||||
parsedURL, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
@@ -52,7 +53,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, address)
|
||||
conn := NewConn(wsConn, address, underlying)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -64,7 +65,10 @@ func prepareURL(address string) (string, error) {
|
||||
return strings.Replace(address, "rel", "ws", 1), nil
|
||||
}
|
||||
|
||||
func httpClientNbDialer() *http.Client {
|
||||
// httpClientNbDialer builds the http client used by the websocket library.
|
||||
// underlyingOut, when non-nil, is populated with the raw conn from the
|
||||
// transport's DialContext so the caller can read its RemoteAddr.
|
||||
func httpClientNbDialer(serverName string, underlyingOut *net.Conn) *http.Client {
|
||||
customDialer := nbnet.NewDialer()
|
||||
|
||||
certPool, err := x509.SystemCertPool()
|
||||
@@ -75,10 +79,15 @@ func httpClientNbDialer() *http.Client {
|
||||
|
||||
customTransport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return customDialer.DialContext(ctx, network, addr)
|
||||
c, err := customDialer.DialContext(ctx, network, addr)
|
||||
if err == nil && underlyingOut != nil {
|
||||
*underlyingOut = c
|
||||
}
|
||||
return c, err
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: certPool,
|
||||
RootCAs: certPool,
|
||||
ServerName: serverName,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -75,6 +76,9 @@ type Manager struct {
|
||||
|
||||
mtu uint16
|
||||
maxBackoffInterval time.Duration
|
||||
|
||||
cleanupInterval time.Duration
|
||||
keepUnusedServerTime time.Duration
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
@@ -95,6 +99,8 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
cleanupInterval: relayCleanupInterval,
|
||||
keepUnusedServerTime: keepUnusedServerTime,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
@@ -130,7 +136,10 @@ func (m *Manager) Serve() error {
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||
//
|
||||
// serverIP, when valid and serverAddress is foreign, is used as a dial target if the FQDN-based dial fails.
|
||||
// Ignored for the local home-server path. TLS verification still uses the FQDN via SNI.
|
||||
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
@@ -151,7 +160,7 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (
|
||||
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
|
||||
} else {
|
||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||
netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
|
||||
netConn, err = m.openConnVia(ctx, serverAddress, peerKey, serverIP)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -203,16 +212,22 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
// RelayInstanceAddress returns the address and resolved IP of the permanent relay server. It could change if the
|
||||
// network connection is lost. The address is sent to the target peer to choose the common relay server for the
|
||||
// communication; the IP is sent alongside so remote peers can dial directly without their own DNS lookup. Both
|
||||
// values are read under the same lock so they cannot diverge across a reconnection.
|
||||
func (m *Manager) RelayInstanceAddress() (string, netip.Addr, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return "", ErrRelayClientNotConnected
|
||||
return "", netip.Addr{}, ErrRelayClientNotConnected
|
||||
}
|
||||
return m.relayClient.ServerInstanceURL()
|
||||
addr, err := m.relayClient.ServerInstanceURL()
|
||||
if err != nil {
|
||||
return "", netip.Addr{}, err
|
||||
}
|
||||
return addr, m.relayClient.ConnectedIP(), nil
|
||||
}
|
||||
|
||||
// ServerURLs returns the addresses of the relay servers.
|
||||
@@ -236,7 +251,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
||||
return m.tokenStore.UpdateToken(token)
|
||||
}
|
||||
|
||||
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) {
|
||||
// check if already has a connection to the desired relay server
|
||||
m.relayClientsMutex.RLock()
|
||||
rt, ok := m.relayClients[serverAddress]
|
||||
@@ -271,7 +286,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
@@ -364,7 +379,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
}
|
||||
|
||||
func (m *Manager) startCleanupLoop() {
|
||||
ticker := time.NewTicker(relayCleanupInterval)
|
||||
ticker := time.NewTicker(m.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
@@ -389,7 +404,7 @@ func (m *Manager) cleanUpUnusedRelays() {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(rt.created) <= keepUnusedServerTime {
|
||||
if time.Since(rt.created) <= m.keepUnusedServerTime {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
144
shared/relay/client/manager_serverip_test.go
Normal file
144
shared/relay/client/manager_serverip_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
// TestManager_ForeignRelayServerIP exercises the foreign-relay path
|
||||
// end-to-end through Manager.OpenConn. Alice and Bob register on different
|
||||
// relay servers; Alice dials Bob's foreign relay using an unresolvable
|
||||
// FQDN. Without a server IP the dial fails; with Bob's advertised IP it
|
||||
// recovers and a payload round-trips between the peers.
|
||||
func TestManager_ForeignRelayServerIP(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Alice's home relay
|
||||
homeCfg := server.ListenerConfig{Address: "127.0.0.1:52401"}
|
||||
homeSrv, err := server.NewServer(newManagerTestServerConfig(homeCfg.Address))
|
||||
if err != nil {
|
||||
t.Fatalf("create home server: %s", err)
|
||||
}
|
||||
homeErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := homeSrv.Listen(homeCfg); err != nil {
|
||||
homeErr <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() { _ = homeSrv.Shutdown(context.Background()) })
|
||||
if err := waitForServerToStart(homeErr); err != nil {
|
||||
t.Fatalf("home server: %s", err)
|
||||
}
|
||||
|
||||
// Bob's foreign relay
|
||||
foreignCfg := server.ListenerConfig{Address: "127.0.0.1:52402"}
|
||||
foreignSrv, err := server.NewServer(newManagerTestServerConfig(foreignCfg.Address))
|
||||
if err != nil {
|
||||
t.Fatalf("create foreign server: %s", err)
|
||||
}
|
||||
foreignErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := foreignSrv.Listen(foreignCfg); err != nil {
|
||||
foreignErr <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() { _ = foreignSrv.Shutdown(context.Background()) })
|
||||
if err := waitForServerToStart(foreignErr); err != nil {
|
||||
t.Fatalf("foreign server: %s", err)
|
||||
}
|
||||
|
||||
mCtx, mCancel := context.WithCancel(ctx)
|
||||
t.Cleanup(mCancel)
|
||||
|
||||
mgrAlice := NewManager(mCtx, toURL(homeCfg), "alice", iface.DefaultMTU)
|
||||
if err := mgrAlice.Serve(); err != nil {
|
||||
t.Fatalf("alice manager serve: %s", err)
|
||||
}
|
||||
|
||||
mgrBob := NewManager(mCtx, toURL(foreignCfg), "bob", iface.DefaultMTU)
|
||||
if err := mgrBob.Serve(); err != nil {
|
||||
t.Fatalf("bob manager serve: %s", err)
|
||||
}
|
||||
|
||||
// Bob's real relay URL and the IP that would ride along in signal as relayServerIP.
|
||||
bobRealAddr, bobAdvertisedIP, err := mgrBob.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Fatalf("bob relay address: %s", err)
|
||||
}
|
||||
if !bobAdvertisedIP.IsValid() {
|
||||
t.Fatalf("expected valid RelayInstanceIP for bob, got zero")
|
||||
}
|
||||
|
||||
// .invalid is reserved (RFC 2606), so DNS resolution always fails.
|
||||
const brokenFQDN = "rel://relay-bob-instance.invalid:52402"
|
||||
if brokenFQDN == bobRealAddr {
|
||||
t.Fatalf("broken FQDN must differ from bob's real address (%s)", bobRealAddr)
|
||||
}
|
||||
|
||||
t.Run("no server IP, dial fails", func(t *testing.T) {
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer dialCancel()
|
||||
_, err := mgrAlice.OpenConn(dialCtx, brokenFQDN, "bob", netip.Addr{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected OpenConn to fail without server IP, got success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server IP recovers", func(t *testing.T) {
|
||||
// Bob waits for Alice's incoming peer connection on his side.
|
||||
bobSideCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := mgrBob.OpenConn(ctx, bobRealAddr, "alice", netip.Addr{})
|
||||
if err != nil {
|
||||
bobSideCh <- err
|
||||
return
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
bobSideCh <- err
|
||||
return
|
||||
}
|
||||
if _, err := conn.Write(buf[:n]); err != nil {
|
||||
bobSideCh <- err
|
||||
return
|
||||
}
|
||||
bobSideCh <- nil
|
||||
}()
|
||||
|
||||
aliceConn, err := mgrAlice.OpenConn(ctx, brokenFQDN, "bob", bobAdvertisedIP)
|
||||
if err != nil {
|
||||
t.Fatalf("alice OpenConn with server IP: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = aliceConn.Close() })
|
||||
|
||||
payload := []byte("alice-to-bob")
|
||||
if _, err := aliceConn.Write(payload); err != nil {
|
||||
t.Fatalf("alice write: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(aliceConn, buf); err != nil {
|
||||
t.Fatalf("alice read echo: %s", err)
|
||||
}
|
||||
if string(buf) != string(payload) {
|
||||
t.Fatalf("echo mismatch: got %q want %q", buf, payload)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-bobSideCh:
|
||||
if err != nil {
|
||||
t.Fatalf("bob side: %s", err)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("timed out waiting for bob side")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -101,15 +102,15 @@ func TestForeignConn(t *testing.T) {
|
||||
if err := clientBob.Serve(); err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
||||
bobsSrvAddr, _, err := clientBob.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get relay address: %s", err)
|
||||
}
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@@ -209,7 +210,7 @@ func TestForeginConnClose(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
|
||||
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@@ -301,7 +302,7 @@ func TestForeignAutoClose(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
|
||||
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer", netip.Addr{}); err == nil {
|
||||
t.Fatalf("should have failed to open connection to another peer")
|
||||
}
|
||||
|
||||
@@ -367,11 +368,11 @@ func TestAutoReconnect(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
ra, err := clientAlice.RelayInstanceAddress()
|
||||
ra, _, err := clientAlice.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ctx, ra, "bob")
|
||||
conn, err := clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
@@ -391,7 +392,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
}
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ctx, ra, "bob")
|
||||
_, err = clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
@@ -453,7 +454,7 @@ func TestNotifierDoubleAdd(t *testing.T) {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
|
||||
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
@@ -14,17 +15,17 @@ import (
|
||||
|
||||
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
|
||||
|
||||
// Status is the status of the client
|
||||
type Status string
|
||||
|
||||
const StreamConnected Status = "Connected"
|
||||
const StreamDisconnected Status = "Disconnected"
|
||||
|
||||
const (
|
||||
StreamConnected Status = "Connected"
|
||||
StreamDisconnected Status = "Disconnected"
|
||||
|
||||
// DirectCheck indicates support to direct mode checks
|
||||
DirectCheck uint32 = 1
|
||||
)
|
||||
|
||||
// Status is the status of the client
|
||||
type Status string
|
||||
|
||||
type Client interface {
|
||||
io.Closer
|
||||
StreamConnected() bool
|
||||
@@ -38,6 +39,24 @@ type Client interface {
|
||||
SetOnReconnectedListener(func())
|
||||
}
|
||||
|
||||
// Credential is an instance of a GrpcClient's Credential
|
||||
type Credential struct {
|
||||
UFrag string
|
||||
Pwd string
|
||||
}
|
||||
|
||||
// CredentialPayload bundles the fields of a signal Body for MarshalCredential.
|
||||
type CredentialPayload struct {
|
||||
Type proto.Body_Type
|
||||
WgListenPort int
|
||||
Credential *Credential
|
||||
RosenpassPubKey []byte
|
||||
RosenpassAddr string
|
||||
RelaySrvAddress string
|
||||
RelaySrvIP netip.Addr
|
||||
SessionID []byte
|
||||
}
|
||||
|
||||
// UnMarshalCredential parses the credentials from the message and returns a Credential instance
|
||||
func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
|
||||
|
||||
@@ -52,27 +71,27 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
|
||||
}
|
||||
|
||||
// MarshalCredential marshal a Credential instance and returns a Message object
|
||||
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string, sessionID []byte) (*proto.Message, error) {
|
||||
func MarshalCredential(myKey wgtypes.Key, remoteKey string, p CredentialPayload) (*proto.Message, error) {
|
||||
body := &proto.Body{
|
||||
Type: p.Type,
|
||||
Payload: fmt.Sprintf("%s:%s", p.Credential.UFrag, p.Credential.Pwd),
|
||||
WgListenPort: uint32(p.WgListenPort),
|
||||
NetBirdVersion: version.NetbirdVersion(),
|
||||
RosenpassConfig: &proto.RosenpassConfig{
|
||||
RosenpassPubKey: p.RosenpassPubKey,
|
||||
RosenpassServerAddr: p.RosenpassAddr,
|
||||
},
|
||||
SessionId: p.SessionID,
|
||||
}
|
||||
if p.RelaySrvAddress != "" {
|
||||
body.RelayServerAddress = &p.RelaySrvAddress
|
||||
}
|
||||
if p.RelaySrvIP.IsValid() {
|
||||
body.RelayServerIP = p.RelaySrvIP.Unmap().AsSlice()
|
||||
}
|
||||
return &proto.Message{
|
||||
Key: myKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey,
|
||||
Body: &proto.Body{
|
||||
Type: t,
|
||||
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
|
||||
WgListenPort: uint32(myPort),
|
||||
NetBirdVersion: version.NetbirdVersion(),
|
||||
RosenpassConfig: &proto.RosenpassConfig{
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassServerAddr: rosenpassAddr,
|
||||
},
|
||||
RelayServerAddress: relaySrvAddress,
|
||||
SessionId: sessionID,
|
||||
},
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Credential is an instance of a GrpcClient's Credential
|
||||
type Credential struct {
|
||||
UFrag string
|
||||
Pwd string
|
||||
}
|
||||
|
||||
@@ -167,7 +167,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
|
||||
// start receiving messages from the Signal stream (from other peers through signal)
|
||||
err = c.receive(stream)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
if ctx.Err() != nil {
|
||||
log.Debugf("signal connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -229,8 +229,13 @@ type Body struct {
|
||||
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
|
||||
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
|
||||
// relayServerAddress is url of the relay server
|
||||
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
|
||||
SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"`
|
||||
RelayServerAddress *string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3,oneof" json:"relayServerAddress,omitempty"`
|
||||
SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"`
|
||||
// relayServerIP is the IP the sender is connected to on its relay server,
|
||||
// encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a
|
||||
// fallback dial target when DNS resolution of relayServerAddress fails.
|
||||
// SNI/TLS verification still uses relayServerAddress.
|
||||
RelayServerIP []byte `protobuf:"bytes,11,opt,name=relayServerIP,proto3,oneof" json:"relayServerIP,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Body) Reset() {
|
||||
@@ -315,8 +320,8 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
|
||||
}
|
||||
|
||||
func (x *Body) GetRelayServerAddress() string {
|
||||
if x != nil {
|
||||
return x.RelayServerAddress
|
||||
if x != nil && x.RelayServerAddress != nil {
|
||||
return *x.RelayServerAddress
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -328,6 +333,13 @@ func (x *Body) GetSessionId() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Body) GetRelayServerIP() []byte {
|
||||
if x != nil {
|
||||
return x.RelayServerIP
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mode indicates a connection mode
|
||||
type Mode struct {
|
||||
state protoimpl.MessageState
|
||||
@@ -451,7 +463,7 @@ var file_signalexchange_proto_rawDesc = []byte{
|
||||
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
|
||||
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
|
||||
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
|
||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xc3, 0x04, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
||||
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
|
||||
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
|
||||
@@ -471,40 +483,46 @@ var file_signalexchange_proto_rawDesc = []byte{
|
||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
|
||||
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
|
||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
|
||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x33, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
|
||||
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f,
|
||||
0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73,
|
||||
0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70,
|
||||
0x28, 0x09, 0x48, 0x00, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65,
|
||||
0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x88, 0x01, 0x01, 0x12, 0x21, 0x0a, 0x09, 0x73,
|
||||
0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x01,
|
||||
0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x29,
|
||||
0x0a, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x18,
|
||||
0x0b, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x02, 0x52, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65,
|
||||
0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70,
|
||||
0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06,
|
||||
0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44,
|
||||
0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10,
|
||||
0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c,
|
||||
0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04,
|
||||
0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01,
|
||||
0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f,
|
||||
0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
|
||||
0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b,
|
||||
0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
|
||||
0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73,
|
||||
0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
|
||||
0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e,
|
||||
0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c,
|
||||
0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
|
||||
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
|
||||
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
|
||||
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
|
||||
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d,
|
||||
0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e,
|
||||
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
|
||||
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a,
|
||||
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
|
||||
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x15,
|
||||
0x0a, 0x13, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
|
||||
0x64, 0x72, 0x65, 0x73, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f,
|
||||
0x6e, 0x49, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72,
|
||||
0x76, 0x65, 0x72, 0x49, 0x50, 0x4a, 0x04, 0x08, 0x09, 0x10, 0x0a, 0x22, 0x2e, 0x0a, 0x04, 0x4d,
|
||||
0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01,
|
||||
0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52,
|
||||
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28,
|
||||
0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65,
|
||||
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61,
|
||||
0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65,
|
||||
0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18,
|
||||
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a,
|
||||
0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
|
||||
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
|
||||
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
|
||||
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
|
||||
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43,
|
||||
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e,
|
||||
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20,
|
||||
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
|
||||
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -63,9 +63,17 @@ message Body {
|
||||
RosenpassConfig rosenpassConfig = 7;
|
||||
|
||||
// relayServerAddress is url of the relay server
|
||||
string relayServerAddress = 8;
|
||||
optional string relayServerAddress = 8;
|
||||
|
||||
reserved 9;
|
||||
|
||||
optional bytes sessionId = 10;
|
||||
|
||||
// relayServerIP is the IP the sender is connected to on its relay server,
|
||||
// encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a
|
||||
// fallback dial target when DNS resolution of relayServerAddress fails.
|
||||
// SNI/TLS verification still uses relayServerAddress.
|
||||
optional bytes relayServerIP = 11;
|
||||
}
|
||||
|
||||
// Mode indicates a connection mode
|
||||
|
||||
@@ -10,15 +10,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/gopacket/routing"
|
||||
"github.com/libp2p/go-netroute"
|
||||
"github.com/mdlayher/socket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
@@ -37,8 +35,6 @@ type SharedSocket struct {
|
||||
conn6 *socket.Conn
|
||||
port int
|
||||
mtu uint16
|
||||
routerMux sync.RWMutex
|
||||
router routing.Router
|
||||
packetDemux chan rcvdPacket
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -82,11 +78,6 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
|
||||
}
|
||||
}()
|
||||
|
||||
rawSock.router, err = netroute.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
|
||||
}
|
||||
|
||||
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
|
||||
@@ -127,31 +118,26 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
|
||||
go rawSock.read(rawSock.conn6.Recvfrom)
|
||||
}
|
||||
|
||||
go rawSock.updateRouter()
|
||||
|
||||
return rawSock, nil
|
||||
}
|
||||
|
||||
// updateRouter updates the listener routing table client
|
||||
// this is needed to avoid outdated information across different client networks
|
||||
func (s *SharedSocket) updateRouter() {
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
router, err := netroute.New()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create and update packet router for stunListener: %s", err)
|
||||
continue
|
||||
}
|
||||
s.routerMux.Lock()
|
||||
s.router = router
|
||||
s.routerMux.Unlock()
|
||||
// resolveSrc returns the source IP the kernel will pick for a packet sent to
|
||||
// dst by these raw sockets, mirroring the fwmark the kernel will see on send.
|
||||
func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) {
|
||||
opts := &netlink.RouteGetOptions{}
|
||||
if nbnet.AdvancedRouting() {
|
||||
opts.Mark = nbnet.ControlPlaneMark
|
||||
}
|
||||
routes, err := netlink.RouteGetWithOptions(dst, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("route get %s: %w", dst, err)
|
||||
}
|
||||
for _, r := range routes {
|
||||
if r.Src != nil {
|
||||
return r.Src, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no source IP for %s", dst)
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address, preferring IPv4 for backward compatibility.
|
||||
@@ -310,15 +296,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
DstPort: layers.UDPPort(rUDPAddr.Port),
|
||||
}
|
||||
|
||||
s.routerMux.RLock()
|
||||
defer s.routerMux.RUnlock()
|
||||
|
||||
_, _, src, err := s.router.Route(rUDPAddr.IP)
|
||||
src, err := s.resolveSrc(rUDPAddr.IP)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("got an error while checking route, err: %w", err)
|
||||
return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err)
|
||||
}
|
||||
|
||||
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
|
||||
if conn == nil {
|
||||
return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP)
|
||||
}
|
||||
|
||||
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
|
||||
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user