mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-31 13:09:55 +00:00
Compare commits
50 Commits
fix/wiregu
...
nmap/compo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eab0826b4e | ||
|
|
7048b87931 | ||
|
|
174dc24867 | ||
|
|
596952265d | ||
|
|
7ea5e37dd4 | ||
|
|
9d7ef9b255 | ||
|
|
944a258459 | ||
|
|
21cfec93d4 | ||
|
|
1f9a829f2c | ||
|
|
98818e3095 | ||
|
|
14af179556 | ||
|
|
1fbb5e6d5d | ||
|
|
6771e35d57 | ||
|
|
e89b1e0596 | ||
|
|
5d5c2d9f95 | ||
|
|
d542c60e21 | ||
|
|
4983b5cf17 | ||
|
|
b3b0feb3b8 | ||
|
|
7aebdd69dd | ||
|
|
0358be2313 | ||
|
|
13e41e432c | ||
|
|
37052fd5bc | ||
|
|
454ff66518 | ||
|
|
6137a1fcc5 | ||
|
|
4955c345d5 | ||
|
|
9192b4f029 | ||
|
|
efa6a3f502 | ||
|
|
c784b02550 | ||
|
|
d250f92c43 | ||
|
|
80966ab1b0 | ||
|
|
5fbcdeceac | ||
|
|
3a1bbeba90 | ||
|
|
728057ef15 | ||
|
|
582cd70086 | ||
|
|
9bbbafaf69 | ||
|
|
672b057aa0 | ||
|
|
b9a0186200 | ||
|
|
9083bdb977 | ||
|
|
b194af48b8 | ||
|
|
4543780ef0 | ||
|
|
2de0283971 | ||
|
|
af24fd7796 | ||
|
|
13d32d274f | ||
|
|
705f87fc20 | ||
|
|
3f91f49277 | ||
|
|
347c5bf317 | ||
|
|
22e2519d71 | ||
|
|
e916f12cca | ||
|
|
9ed2e2a5b4 | ||
|
|
2ccae7ec47 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,6 +12,7 @@
|
|||||||
- [ ] Is a feature enhancement
|
- [ ] Is a feature enhancement
|
||||||
- [ ] It is a refactor
|
- [ ] It is a refactor
|
||||||
- [ ] Created tests that fail without the change (if possible)
|
- [ ] Created tests that fail without the change (if possible)
|
||||||
|
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||||
|
|
||||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||||
|
|
||||||
|
|||||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
|||||||
display_name: Linux
|
display_name: Linux
|
||||||
name: ${{ matrix.display_name }}
|
name: ${{ matrix.display_name }}
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
timeout-minutes: 15
|
timeout-minutes: 25
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -58,4 +58,4 @@ jobs:
|
|||||||
skip-cache: true
|
skip-cache: true
|
||||||
skip-save-cache: true
|
skip-save-cache: true
|
||||||
cache-invalidation-interval: 0
|
cache-invalidation-interval: 0
|
||||||
args: --timeout=12m
|
args: --timeout=20m
|
||||||
|
|||||||
66
.github/workflows/proto-version-check.yml
vendored
66
.github/workflows/proto-version-check.yml
vendored
@@ -20,34 +20,66 @@ jobs:
|
|||||||
per_page: 100,
|
per_page: 100,
|
||||||
});
|
});
|
||||||
|
|
||||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
const modifiedPbFiles = files.filter(
|
||||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||||
if (missingPatch.length > 0) {
|
);
|
||||||
core.setFailed(
|
if (modifiedPbFiles.length === 0) {
|
||||||
`Cannot inspect patch data for:\n` +
|
console.log('No modified .pb.go files to check');
|
||||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
|
||||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
|
||||||
);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
|
||||||
const violations = [];
|
|
||||||
|
|
||||||
for (const file of pbFiles) {
|
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||||
const changed = file.patch
|
const baseSha = context.payload.pull_request.base.sha;
|
||||||
.split('\n')
|
const headSha = context.payload.pull_request.head.sha;
|
||||||
.filter(line => versionPattern.test(line));
|
|
||||||
if (changed.length > 0) {
|
async function getVersionHeader(path, ref) {
|
||||||
|
try {
|
||||||
|
const res = await github.rest.repos.getContent({
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
path,
|
||||||
|
ref,
|
||||||
|
});
|
||||||
|
if (!res.data.content) {
|
||||||
|
return { ok: false, reason: 'no inline content (file too large)' };
|
||||||
|
}
|
||||||
|
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
|
||||||
|
const lines = content
|
||||||
|
.split('\n')
|
||||||
|
.slice(0, 20)
|
||||||
|
.filter(line => versionPattern.test(line));
|
||||||
|
return { ok: true, lines };
|
||||||
|
} catch (e) {
|
||||||
|
return { ok: false, reason: e.message };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const violations = [];
|
||||||
|
for (const file of modifiedPbFiles) {
|
||||||
|
const [base, head] = await Promise.all([
|
||||||
|
getVersionHeader(file.filename, baseSha),
|
||||||
|
getVersionHeader(file.filename, headSha),
|
||||||
|
]);
|
||||||
|
if (!base.ok || !head.ok) {
|
||||||
|
core.warning(
|
||||||
|
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||||
violations.push({
|
violations.push({
|
||||||
file: file.filename,
|
file: file.filename,
|
||||||
lines: changed,
|
base: base.lines,
|
||||||
|
head: head.lines,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (violations.length > 0) {
|
if (violations.length > 0) {
|
||||||
const details = violations.map(v =>
|
const details = violations.map(v =>
|
||||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
`${v.file}:\n` +
|
||||||
|
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
|
||||||
|
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
|
||||||
).join('\n\n');
|
).join('\n\n');
|
||||||
|
|
||||||
core.setFailed(
|
core.setFailed(
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
|
|||||||
- [Contributing to NetBird](#contributing-to-netbird)
|
- [Contributing to NetBird](#contributing-to-netbird)
|
||||||
- [Contents](#contents)
|
- [Contents](#contents)
|
||||||
- [Code of conduct](#code-of-conduct)
|
- [Code of conduct](#code-of-conduct)
|
||||||
|
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||||
- [Directory structure](#directory-structure)
|
- [Directory structure](#directory-structure)
|
||||||
- [Development setup](#development-setup)
|
- [Development setup](#development-setup)
|
||||||
- [Requirements](#requirements)
|
- [Requirements](#requirements)
|
||||||
@@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
|||||||
By participating, you are expected to uphold this code. Please report
|
By participating, you are expected to uphold this code. Please report
|
||||||
unacceptable behavior to community@netbird.io.
|
unacceptable behavior to community@netbird.io.
|
||||||
|
|
||||||
|
## Discuss changes with the NetBird team first
|
||||||
|
|
||||||
|
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||||
|
|
||||||
|
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||||
|
|
||||||
|
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||||
|
|
||||||
## Directory structure
|
## Directory structure
|
||||||
|
|
||||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||||
|
|||||||
153
README.md
153
README.md
@@ -1,147 +1,134 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<br/>
|
<p align="center">
|
||||||
<br/>
|
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
|
||||||
<p align="center">
|
</p>
|
||||||
<img width="234" src="docs/media/logo-full.png"/>
|
<p align="center">
|
||||||
</p>
|
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
|
||||||
<p>
|
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
|
||||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
</a>
|
||||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||||
</a>
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
|
||||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
</a>
|
||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
|
||||||
</a>
|
|
||||||
<br>
|
|
||||||
<a href="https://docs.netbird.io/slack-url">
|
<a href="https://docs.netbird.io/slack-url">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
|
||||||
</a>
|
</a>
|
||||||
<a href="https://forum.netbird.io">
|
<a href="https://forum.netbird.io">
|
||||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
|
||||||
</a>
|
</a>
|
||||||
<br>
|
|
||||||
<a href="https://gurubase.io/g/netbird">
|
<a href="https://gurubase.io/g/netbird">
|
||||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>
|
<strong>
|
||||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||||
|
<br/>
|
||||||
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
|
<br/>
|
||||||
|
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||||
|
</strong>
|
||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
<strong>
|
||||||
<br/>
|
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||||
|
</strong>
|
||||||
</strong>
|
|
||||||
<br>
|
|
||||||
<strong>
|
|
||||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
|
||||||
</strong>
|
|
||||||
<br>
|
|
||||||
<br>
|
|
||||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
|
||||||
New: NetBird terraform provider
|
|
||||||
</a>
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<br>
|
|
||||||
|
|
||||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||||
|
|
||||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|
||||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||||
|
|
||||||
### Open Source Network Security in a Single Platform
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### Self-Host NetBird (Video)
|
### Self-host NetBird (video)
|
||||||
|
|
||||||
[](https://youtu.be/bZAgpT6nzaQ)
|
[](https://youtu.be/bZAgpT6nzaQ)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
| Connectivity | Management | Security | Automation| Platforms |
|
| Connectivity | Management | Security | Automation | Platforms |
|
||||||
|----|----|----|----|----|
|
|---|---|---|---|---|
|
||||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
|
||||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
|
||||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
| ✓ Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
|
||||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
| ✓ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
|
||||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
|
||||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
|
||||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
|
||||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
|
||||||
|
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
|
||||||
|
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
|
||||||
|
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
|
||||||
|
| | | | | ✓ OpenWRT |
|
||||||
|
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
|
||||||
|
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
|
||||||
|
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
|
||||||
|
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
|
||||||
|
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
|
||||||
|
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
|
||||||
|
|
||||||
### Quickstart with NetBird Cloud
|
### Quickstart with NetBird Cloud
|
||||||
|
|
||||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
|
||||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
|
||||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
- Check the NetBird [admin UI](https://app.netbird.io/).
|
||||||
- Add more machines.
|
|
||||||
|
|
||||||
### Quickstart with self-hosted NetBird
|
### Quickstart with self-hosted NetBird
|
||||||
|
|
||||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
|
||||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
|
||||||
|
|
||||||
**Infrastructure requirements:**
|
**Infrastructure requirements:**
|
||||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
|
||||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
|
||||||
- **Public domain** name pointing to the VM.
|
- A **public domain** name pointing to the VM.
|
||||||
|
|
||||||
**Software requirements:**
|
**Software requirements:**
|
||||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
|
||||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
|
||||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
|
||||||
- [curl](https://curl.se/) installed.
|
|
||||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
|
||||||
|
|
||||||
**Steps**
|
**Steps**
|
||||||
- Download and run the installation script:
|
- Download and run the installation script:
|
||||||
```bash
|
```bash
|
||||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||||
```
|
```
|
||||||
- Once finished, you can manage the resources via `docker-compose`
|
|
||||||
|
|
||||||
### A bit on NetBird internals
|
### A bit on NetBird internals
|
||||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
|
||||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
|
||||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
|
||||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
|
||||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
|
||||||
|
|
||||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
|
||||||
|
|
||||||
<p float="left" align="middle">
|
<p float="left" align="middle">
|
||||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||||
|
|
||||||
### Community projects
|
### Community projects
|
||||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
|
||||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
|
||||||
|
|
||||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||||
|
|
||||||
### Support acknowledgement
|
### Support acknowledgement
|
||||||
|
|
||||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
### Testimonials
|
### Acknowledgements
|
||||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
|
||||||
|
|
||||||
### Legal
|
### Legal
|
||||||
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||||
|
|
||||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
@@ -84,6 +85,12 @@ type Options struct {
|
|||||||
DisableIPv6 bool
|
DisableIPv6 bool
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
|
// BlockLANAccess blocks the embedded peer from reaching the host's
|
||||||
|
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
|
||||||
|
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
|
||||||
|
// when the embedded client must never act as a stepping stone into
|
||||||
|
// the host's local network (e.g. the proxy's overlay peer).
|
||||||
|
BlockLANAccess bool
|
||||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||||
WireguardPort *int
|
WireguardPort *int
|
||||||
// MTU is the MTU for the tunnel interface.
|
// MTU is the MTU for the tunnel interface.
|
||||||
@@ -94,6 +101,26 @@ type Options struct {
|
|||||||
MTU *uint16
|
MTU *uint16
|
||||||
// DNSLabels defines additional DNS labels configured in the peer.
|
// DNSLabels defines additional DNS labels configured in the peer.
|
||||||
DNSLabels []string
|
DNSLabels []string
|
||||||
|
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||||
|
Performance Performance
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||||
|
//
|
||||||
|
// These settings are process-global: any non-nil field also becomes the
|
||||||
|
// default for Clients constructed by later embed.New calls in the same
|
||||||
|
// process. Nil fields are ignored.
|
||||||
|
type Performance struct {
|
||||||
|
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||||
|
// leaves the pool unbounded. Lower values trade throughput for a
|
||||||
|
// tighter memory ceiling. May also be changed on a running Client via
|
||||||
|
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||||
|
PreallocatedBuffersPerPool *uint32
|
||||||
|
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||||
|
// writes per syscall, which also bounds eager buffer allocation per
|
||||||
|
// worker. Zero uses the platform default. Applied at construction
|
||||||
|
// only; ignored by Client.SetPerformance.
|
||||||
|
MaxBatchSize *uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
DisableIPv6: &opts.DisableIPv6,
|
DisableIPv6: &opts.DisableIPv6,
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
|
BlockLANAccess: &opts.BlockLANAccess,
|
||||||
WireguardPort: opts.WireguardPort,
|
WireguardPort: opts.WireguardPort,
|
||||||
MTU: opts.MTU,
|
MTU: opts.MTU,
|
||||||
DNSLabels: parsedLabels,
|
DNSLabels: parsedLabels,
|
||||||
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
|
|||||||
config.PrivateKey = opts.PrivateKey
|
config.PrivateKey = opts.PrivateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||||
|
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||||
|
}
|
||||||
|
if opts.Performance.MaxBatchSize != nil {
|
||||||
|
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
deviceName: opts.DeviceName,
|
deviceName: opts.DeviceName,
|
||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||||
|
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||||
|
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||||
|
// roster — callers should treat that as "unknown peer".
|
||||||
|
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||||
|
if !ip.IsValid() || c.recorder == nil {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
state, found := c.recorder.PeerStateByIP(ip.String())
|
||||||
|
if !found {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return state.PubKey, state.FQDN, true
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
|||||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||||
|
// takes effect, and only when it was nonzero at construction;
|
||||||
|
// MaxBatchSize is construction-only and returns an error if set here.
|
||||||
|
//
|
||||||
|
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||||
|
// running yet.
|
||||||
|
func (c *Client) SetPerformance(t Performance) error {
|
||||||
|
if t.MaxBatchSize != nil {
|
||||||
|
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||||
|
}
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return engine.SetPerformance(internal.Performance{
|
||||||
|
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// StartCapture begins capturing packets on this client's tunnel device.
|
// StartCapture begins capturing packets on this client's tunnel device.
|
||||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||||
|
|||||||
@@ -52,9 +52,10 @@ func (m *externalChainMonitor) start() {
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
m.cancel = cancel
|
m.cancel = cancel
|
||||||
m.done = make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
m.done = done
|
||||||
|
|
||||||
go m.run(ctx)
|
go m.run(ctx, done)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *externalChainMonitor) stop() {
|
func (m *externalChainMonitor) stop() {
|
||||||
@@ -72,8 +73,8 @@ func (m *externalChainMonitor) stop() {
|
|||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *externalChainMonitor) run(ctx context.Context) {
|
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
|
||||||
defer close(m.done)
|
defer close(done)
|
||||||
|
|
||||||
bo := &backoff.ExponentialBackOff{
|
bo := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: externalMonitorInitInterval,
|
InitialInterval: externalMonitorInitInterval,
|
||||||
|
|||||||
@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
fileDescriptor int32,
|
fileDescriptor int32,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsManager dns.IosDnsManager,
|
dnsManager dns.IosDnsManager,
|
||||||
dnsAddresses []netip.AddrPort,
|
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
) error {
|
) error {
|
||||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||||
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
FileDescriptor: fileDescriptor,
|
FileDescriptor: fileDescriptor,
|
||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
HostDNSAddresses: dnsAddresses,
|
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
|
|||||||
@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
return true
|
return true
|
||||||
case entry.IsWildcard:
|
case entry.IsWildcard:
|
||||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
return strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
|
||||||
default:
|
default:
|
||||||
// For non-wildcard patterns:
|
// For non-wildcard patterns:
|
||||||
// If handler wants subdomain matching, allow suffix match
|
// If handler wants subdomain matching, allow suffix match
|
||||||
|
|||||||
@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
matchSubdomains: true,
|
matchSubdomains: true,
|
||||||
shouldMatch: true,
|
shouldMatch: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard label-boundary mismatch (suffix overlap)",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "x.ab.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard label-boundary match",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "x.b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard multi-label match",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "x.y.b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard no match on multi-label apex",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard no match on unrelated suffix containment",
|
||||||
|
handlerDomain: "*.example.com.",
|
||||||
|
queryDomain: "notexample.com.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard accepts pattern registered without trailing dot",
|
||||||
|
handlerDomain: "*.b.test",
|
||||||
|
queryDomain: "x.b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
expectedCalls: 1,
|
expectedCalls: 1,
|
||||||
expectedHandler: 2, // highest priority matching handler should be called
|
expectedHandler: 2, // highest priority matching handler should be called
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "overlapping wildcard suffixes route to correct handler",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "app.ab.test.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "root zone with specific domain",
|
name: "root zone with specific domain",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ type hostManager interface {
|
|||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
supportCustomPort() bool
|
supportCustomPort() bool
|
||||||
string() string
|
string() string
|
||||||
|
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
|
||||||
|
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
|
||||||
|
// hardcoded Quad9 on iOS, nil for noop / mock.
|
||||||
|
getOriginalNameservers() []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemDNSSettings struct {
|
type SystemDNSSettings struct {
|
||||||
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
|
|||||||
func (n noopHostConfigurator) string() string {
|
func (n noopHostConfigurator) string() string {
|
||||||
return "noop"
|
return "noop"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// androidHostManager is a noop on the OS side (Android's VPN service handles
|
||||||
|
// DNS for us) but tracks the OS-reported resolver list pushed via
|
||||||
|
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
|
||||||
type androidHostManager struct {
|
type androidHostManager struct {
|
||||||
|
holder *hostsDNSHolder
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (*androidHostManager, error) {
|
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
|
||||||
return &androidHostManager{}, nil
|
return &androidHostManager{holder: holder}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||||
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
|
|||||||
func (a androidHostManager) string() string {
|
func (a androidHostManager) string() string {
|
||||||
return "none"
|
return "none"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
|
||||||
|
hosts := a.holder.get()
|
||||||
|
out := make([]netip.Addr, 0, len(hosts))
|
||||||
|
for ap := range hosts {
|
||||||
|
out = append(out, ap.Addr())
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
|
||||||
|
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
|
||||||
|
return []netip.Addr{
|
||||||
|
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
|
||||||
|
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
jsonData, err := json.Marshal(config)
|
jsonData, err := json.Marshal(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@@ -44,9 +45,11 @@ const (
|
|||||||
|
|
||||||
nrptMaxDomainsPerRule = 50
|
nrptMaxDomainsPerRule = 50
|
||||||
|
|
||||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
|
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
|
||||||
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
|
|
||||||
// Network interface DNS registration settings
|
// Network interface DNS registration settings
|
||||||
disableDynamicUpdateKey = "DisableDynamicUpdate"
|
disableDynamicUpdateKey = "DisableDynamicUpdate"
|
||||||
@@ -67,10 +70,11 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type registryConfigurator struct {
|
type registryConfigurator struct {
|
||||||
guid string
|
guid string
|
||||||
routingAll bool
|
routingAll bool
|
||||||
gpo bool
|
gpo bool
|
||||||
nrptEntryCount int
|
nrptEntryCount int
|
||||||
|
origNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||||
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
|||||||
gpo: useGPO,
|
gpo: useGPO,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
origNameservers, err := configurator.captureOriginalNameservers()
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
|
||||||
|
case len(origNameservers) == 0:
|
||||||
|
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
|
||||||
|
default:
|
||||||
|
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
|
||||||
|
}
|
||||||
|
configurator.origNameservers = origNameservers
|
||||||
|
|
||||||
if err := configurator.configureInterface(); err != nil {
|
if err := configurator.configureInterface(); err != nil {
|
||||||
log.Errorf("failed to configure interface settings: %v", err)
|
log.Errorf("failed to configure interface settings: %v", err)
|
||||||
}
|
}
|
||||||
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
|||||||
return configurator, nil
|
return configurator, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
|
||||||
|
// registry key except the WG adapter. v4 and v6 servers live in separate
|
||||||
|
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
|
||||||
|
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||||
|
seen := make(map[netip.Addr]struct{})
|
||||||
|
var out []netip.Addr
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
|
||||||
|
addrs, err := r.captureFromTcpipRoot(root)
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
if _, dup := seen[addr]; dup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[addr] = struct{}{}
|
||||||
|
out = append(out, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
|
||||||
|
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open key: %w", err)
|
||||||
|
}
|
||||||
|
defer closer(root)
|
||||||
|
|
||||||
|
guids, err := root.ReadSubKeyNames(-1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read subkeys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []netip.Addr
|
||||||
|
for _, guid := range guids {
|
||||||
|
if strings.EqualFold(guid, r.guid) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, readInterfaceNameservers(rootPath, guid)...)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
|
||||||
|
keyPath := rootPath + "\\" + guid
|
||||||
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer closer(k)
|
||||||
|
|
||||||
|
// Static NameServer wins over DhcpNameServer for actual resolution.
|
||||||
|
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
|
||||||
|
raw, _, err := k.GetStringValue(name)
|
||||||
|
if err != nil || raw == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if out := parseRegistryNameservers(raw); len(out) > 0 {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRegistryNameservers(raw string) []netip.Addr {
|
||||||
|
var out []netip.Addr
|
||||||
|
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
|
||||||
|
addr, err := netip.ParseAddr(strings.TrimSpace(field))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addr = addr.Unmap()
|
||||||
|
if !addr.IsValid() || addr.IsUnspecified() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Drop unzoned link-local: not routable without a scope id. If
|
||||||
|
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
|
||||||
|
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, addr)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
return slices.Clone(r.origNameservers)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) supportCustomPort() bool {
|
func (r *registryConfigurator) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
|
|||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:unused
|
||||||
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
l := h.unprotectedDNSList
|
l := h.unprotectedDNSList
|
||||||
|
|||||||
@@ -26,6 +26,19 @@ type resolver interface {
|
|||||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
|
||||||
|
// client knows about and whether that peer is currently connected. The
|
||||||
|
// local resolver uses this to suppress A/AAAA answers whose RDATA points
|
||||||
|
// at a disconnected peer (typical case: a synthesized private-service
|
||||||
|
// record pointing at an embedded proxy peer that just went offline).
|
||||||
|
//
|
||||||
|
// known=false means the IP isn't in the local peerstore at all — the
|
||||||
|
// record is left alone (it points at something outside our mesh, e.g.
|
||||||
|
// a non-peer upstream).
|
||||||
|
type PeerConnectivity interface {
|
||||||
|
IsConnectedByIP(ip string) (known, connected bool)
|
||||||
|
}
|
||||||
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
records map[dns.Question][]dns.RR
|
records map[dns.Question][]dns.RR
|
||||||
@@ -33,6 +46,11 @@ type Resolver struct {
|
|||||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||||
zones map[domain.Domain]bool
|
zones map[domain.Domain]bool
|
||||||
resolver resolver
|
resolver resolver
|
||||||
|
// peerConn, when non-nil, is consulted on every A/AAAA answer to
|
||||||
|
// drop records pointing at disconnected peers. nil disables the
|
||||||
|
// filter and preserves the legacy "return whatever is registered"
|
||||||
|
// behaviour for callers that never wire a status source.
|
||||||
|
peerConn PeerConnectivity
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPeerConnectivity wires the per-IP connectivity check used to filter
|
||||||
|
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
|
||||||
|
// Safe to call multiple times; the latest value wins.
|
||||||
|
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
d.peerConn = p
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Resolver) MatchSubdomains() bool {
|
func (d *Resolver) MatchSubdomains() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -76,8 +103,6 @@ func (d *Resolver) ID() types.HandlerID {
|
|||||||
return "local-resolver"
|
return "local-resolver"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Resolver) ProbeAvailability(context.Context) {}
|
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
logger := log.WithFields(log.Fields{
|
logger := log.WithFields(log.Fields{
|
||||||
@@ -97,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
|
||||||
result := d.lookupRecords(logger, question)
|
result := d.lookupRecords(logger, question)
|
||||||
|
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
|
||||||
replyMessage.Authoritative = !result.hasExternalData
|
replyMessage.Authoritative = !result.hasExternalData
|
||||||
replyMessage.Answer = result.records
|
replyMessage.Answer = result.records
|
||||||
replyMessage.Rcode = d.determineRcode(question, result)
|
replyMessage.Rcode = d.determineRcode(question, result)
|
||||||
@@ -438,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
|
||||||
|
// a known but disconnected peer. The synthesized private-service zones
|
||||||
|
// emit one A record per connected proxy peer in a cluster; when a peer
|
||||||
|
// goes offline, the server-side refresh removes the record from the
|
||||||
|
// next netmap, but the client may still hold the previous netmap for a
|
||||||
|
// short window. This filter is the local belt to that braces — even on
|
||||||
|
// the stale netmap, the resolver hides the offline target.
|
||||||
|
//
|
||||||
|
// Records pointing at unknown IPs (outside the local peerstore, e.g.
|
||||||
|
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
|
||||||
|
// through untouched.
|
||||||
|
//
|
||||||
|
// Escape hatch: if filtering would leave the answer empty AND at least
|
||||||
|
// one record was filtered, the original list is returned. Better to
|
||||||
|
// hand the client a record that may not respond than NXDOMAIN it
|
||||||
|
// completely when every proxy peer is offline (the upstream may still
|
||||||
|
// be reachable some other way, or the peerstore may be stale).
|
||||||
|
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||||
|
if len(records) == 0 {
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
d.mu.RLock()
|
||||||
|
checker := d.peerConn
|
||||||
|
d.mu.RUnlock()
|
||||||
|
if checker == nil {
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
kept := make([]dns.RR, 0, len(records))
|
||||||
|
var dropped int
|
||||||
|
for _, rr := range records {
|
||||||
|
ip := extractRecordIP(rr)
|
||||||
|
if ip == "" {
|
||||||
|
kept = append(kept, rr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
known, connected := checker.IsConnectedByIP(ip)
|
||||||
|
if known && !connected {
|
||||||
|
dropped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, rr)
|
||||||
|
}
|
||||||
|
if dropped == 0 {
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
if len(kept) == 0 {
|
||||||
|
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
|
||||||
|
return kept
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
|
||||||
|
// an A or AAAA record, or "" for any other record type.
|
||||||
|
func extractRecordIP(rr dns.RR) string {
|
||||||
|
switch r := rr.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
if r.A == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return r.A.String()
|
||||||
|
case *dns.AAAA:
|
||||||
|
if r.AAAA == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return r.AAAA.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Update replaces all zones and their records
|
// Update replaces all zones and their records
|
||||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
|
|||||||
@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mockPeerConnectivity returns canned (known, connected) results per IP.
|
||||||
|
// Used by the disconnected-peer filter tests below. IPs not in the map
|
||||||
|
// are reported as unknown so the filter leaves them alone.
|
||||||
|
type mockPeerConnectivity struct {
|
||||||
|
byIP map[string]struct{ known, connected bool }
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||||
|
v, ok := m.byIP[ip]
|
||||||
|
if !ok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return v.known, v.connected
|
||||||
|
}
|
||||||
|
|
||||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||||
recordA := nbdns.SimpleRecord{
|
recordA := nbdns.SimpleRecord{
|
||||||
Name: "peera.netbird.cloud.",
|
Name: "peera.netbird.cloud.",
|
||||||
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
|||||||
resolver.isInManagedZone(qname)
|
resolver.isInManagedZone(qname)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
|
||||||
|
// connectivity-aware filtering layered on top of lookupRecords:
|
||||||
|
// when an A record's IP belongs to a known peer that's disconnected,
|
||||||
|
// the record is dropped from the answer. Records for unknown IPs pass
|
||||||
|
// through. If filtering would empty the answer entirely and at least
|
||||||
|
// one record was dropped, the original list is restored (escape hatch
|
||||||
|
// for the "all proxies offline" case).
|
||||||
|
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||||
|
zone := "svc.cluster.netbird."
|
||||||
|
connectedRec := nbdns.SimpleRecord{
|
||||||
|
Name: zone,
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 5,
|
||||||
|
RData: "100.64.0.10",
|
||||||
|
}
|
||||||
|
disconnectedRec := nbdns.SimpleRecord{
|
||||||
|
Name: zone,
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 5,
|
||||||
|
RData: "100.64.0.11",
|
||||||
|
}
|
||||||
|
unknownRec := nbdns.SimpleRecord{
|
||||||
|
Name: zone,
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 5,
|
||||||
|
RData: "203.0.113.5",
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipState struct{ known, connected bool }
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
records []nbdns.SimpleRecord
|
||||||
|
connByIP map[string]ipState
|
||||||
|
wantInOrder []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "drops disconnected peer, keeps connected",
|
||||||
|
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||||
|
connByIP: map[string]ipState{
|
||||||
|
"100.64.0.10": {known: true, connected: true},
|
||||||
|
"100.64.0.11": {known: true, connected: false},
|
||||||
|
},
|
||||||
|
wantInOrder: []string{"100.64.0.10"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown IPs pass through untouched",
|
||||||
|
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
|
||||||
|
connByIP: map[string]ipState{
|
||||||
|
"100.64.0.11": {known: true, connected: false},
|
||||||
|
},
|
||||||
|
wantInOrder: []string{"203.0.113.5"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all disconnected falls back to original list",
|
||||||
|
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
|
||||||
|
connByIP: map[string]ipState{
|
||||||
|
"100.64.0.10": {known: true, connected: false},
|
||||||
|
"100.64.0.11": {known: true, connected: false},
|
||||||
|
},
|
||||||
|
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no checker wired returns all records",
|
||||||
|
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||||
|
connByIP: nil,
|
||||||
|
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
if tc.connByIP != nil {
|
||||||
|
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
|
||||||
|
for ip, st := range tc.connByIP {
|
||||||
|
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
|
||||||
|
}
|
||||||
|
resolver.SetPeerConnectivity(cm)
|
||||||
|
}
|
||||||
|
resolver.Update([]nbdns.CustomZone{{
|
||||||
|
Domain: strings.TrimSuffix(zone, "."),
|
||||||
|
Records: tc.records,
|
||||||
|
NonAuthoritative: true,
|
||||||
|
}})
|
||||||
|
|
||||||
|
var got *dns.Msg
|
||||||
|
writer := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
got = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
|
||||||
|
resolver.ServeDNS(writer, req)
|
||||||
|
|
||||||
|
require.NotNil(t, got, "resolver must produce a response")
|
||||||
|
require.Len(t, got.Answer, len(tc.wantInOrder),
|
||||||
|
"answer count must match expected: %v", tc.wantInOrder)
|
||||||
|
for i, want := range tc.wantInOrder {
|
||||||
|
a, ok := got.Answer[i].(*dns.A)
|
||||||
|
require.True(t, ok, "answer[%d] must be an A record", i)
|
||||||
|
assert.Equal(t, want, a.A.String(),
|
||||||
|
"answer[%d] expected %s got %s", i, want, a.A.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
|
|||||||
return make([]string, 0)
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
|
||||||
func (m *MockServer) ProbeAvailability() {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||||
if m.UpdateServerConfigFunc != nil {
|
if m.UpdateServerConfigFunc != nil {
|
||||||
return m.UpdateServerConfigFunc(domains)
|
return m.UpdateServerConfigFunc(domains)
|
||||||
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
|
// SetRouteSources mock implementation of SetRouteSources from Server interface
|
||||||
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||||
// Mock implementation - no-op
|
// Mock implementation - no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,6 +33,15 @@ const (
|
|||||||
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
|
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
|
||||||
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
|
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
|
||||||
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
|
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
|
||||||
|
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
|
||||||
|
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
|
||||||
|
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
|
||||||
|
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
|
||||||
|
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
|
||||||
|
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
|
||||||
|
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
|
||||||
|
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
|
||||||
|
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
|
||||||
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
|
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
|
||||||
networkManagerDbusIPv4Key = "ipv4"
|
networkManagerDbusIPv4Key = "ipv4"
|
||||||
networkManagerDbusIPv6Key = "ipv6"
|
networkManagerDbusIPv6Key = "ipv6"
|
||||||
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
type networkManagerDbusConfigurator struct {
|
type networkManagerDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
routingAll bool
|
||||||
ifaceName string
|
ifaceName string
|
||||||
|
origNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||||
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
|
|||||||
|
|
||||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
||||||
|
|
||||||
return &networkManagerDbusConfigurator{
|
c := &networkManagerDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
ifaceName: wgInterface,
|
ifaceName: wgInterface,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
origNameservers, err := c.captureOriginalNameservers()
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
log.Warnf("capture original nameservers from NetworkManager: %v", err)
|
||||||
|
case len(origNameservers) == 0:
|
||||||
|
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
|
||||||
|
default:
|
||||||
|
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
|
||||||
|
}
|
||||||
|
c.origNameservers = origNameservers
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureOriginalNameservers reads DNS servers from every NM device's
|
||||||
|
// IP4Config / IP6Config except our WG device.
|
||||||
|
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||||
|
devices, err := networkManagerListDevices()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list devices: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[netip.Addr]struct{})
|
||||||
|
var out []netip.Addr
|
||||||
|
for _, dev := range devices {
|
||||||
|
if dev == n.dbusLinkObject {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ifaceName := readNetworkManagerDeviceInterface(dev)
|
||||||
|
for _, addr := range readNetworkManagerDeviceDNS(dev) {
|
||||||
|
addr = addr.Unmap()
|
||||||
|
if !addr.IsValid() || addr.IsUnspecified() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// IP6Config.Nameservers is a byte slice without zone info;
|
||||||
|
// reattach the device's interface name so a captured fe80::…
|
||||||
|
// stays routable.
|
||||||
|
if addr.IsLinkLocalUnicast() && ifaceName != "" {
|
||||||
|
addr = addr.WithZone(ifaceName)
|
||||||
|
}
|
||||||
|
if _, dup := seen[addr]; dup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[addr] = struct{}{}
|
||||||
|
out = append(out, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
|
||||||
|
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
s, _ := v.Value().(string)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
|
||||||
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
var devs []dbus.ObjectPath
|
||||||
|
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return devs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
|
||||||
|
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
|
||||||
|
var out []netip.Addr
|
||||||
|
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
|
||||||
|
out = append(out, readIPv4ConfigDNS(path)...)
|
||||||
|
}
|
||||||
|
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
|
||||||
|
out = append(out, readIPv6ConfigDNS(path)...)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
|
||||||
|
v, err := obj.GetProperty(property)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
path, ok := v.Value().(dbus.ObjectPath)
|
||||||
|
if !ok || path == "/" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||||
|
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
|
||||||
|
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
|
||||||
|
// legacy uint32 Nameservers property.
|
||||||
|
if out := readIPv4NameserverData(obj); len(out) > 0 {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
return readIPv4LegacyNameservers(obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
|
||||||
|
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
entries, ok := v.Value().([]map[string]dbus.Variant)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var out []netip.Addr
|
||||||
|
for _, entry := range entries {
|
||||||
|
addrVar, ok := entry["address"]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s, ok := addrVar.Value().(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if a, err := netip.ParseAddr(s); err == nil {
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
|
||||||
|
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, ok := v.Value().([]uint32)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]netip.Addr, 0, len(raw))
|
||||||
|
for _, n := range raw {
|
||||||
|
var b [4]byte
|
||||||
|
binary.LittleEndian.PutUint32(b[:], n)
|
||||||
|
out = append(out, netip.AddrFrom4(b))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||||
|
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, ok := v.Value().([][]byte)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]netip.Addr, 0, len(raw))
|
||||||
|
for _, b := range raw {
|
||||||
|
if a, ok := netip.AddrFromSlice(b); ok {
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
return slices.Clone(n.origNameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||||
return newHostManager()
|
return newHostManager(s.hostsDNSHolder)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
@@ -31,8 +32,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -101,16 +104,17 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||||
var srvs []netip.AddrPort
|
var srvs []netip.AddrPort
|
||||||
for _, srv := range servers {
|
for _, srv := range servers {
|
||||||
srvs = append(srvs, srv.AddrPort())
|
srvs = append(srvs, srv.AddrPort())
|
||||||
}
|
}
|
||||||
return &upstreamResolverBase{
|
u := &upstreamResolverBase{
|
||||||
domain: domain,
|
domain: domain.Domain(d),
|
||||||
upstreamServers: srvs,
|
cancel: func() {},
|
||||||
cancel: func() {},
|
|
||||||
}
|
}
|
||||||
|
u.addRace(srvs)
|
||||||
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|
||||||
hostManager := &mockHostConfigurator{}
|
|
||||||
server := DefaultServer{
|
|
||||||
ctx: context.Background(),
|
|
||||||
service: NewServiceViaMemory(&mocWGIface{}),
|
|
||||||
localResolver: local.NewResolver(),
|
|
||||||
handlerChain: NewHandlerChain(),
|
|
||||||
hostManager: hostManager,
|
|
||||||
currentConfig: HostDNSConfig{
|
|
||||||
Domains: []DomainConfig{
|
|
||||||
{false, "domain0", false},
|
|
||||||
{false, "domain1", false},
|
|
||||||
{false, "domain2", false},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
statusRecorder: peer.NewRecorder("mgm"),
|
|
||||||
}
|
|
||||||
|
|
||||||
var domainsUpdate string
|
|
||||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
|
||||||
domains := []string{}
|
|
||||||
for _, item := range config.Domains {
|
|
||||||
if item.Disabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
domains = append(domains, item.Domain)
|
|
||||||
}
|
|
||||||
domainsUpdate = strings.Join(domains, ",")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
|
||||||
Domains: []string{"domain1"},
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
|
||||||
},
|
|
||||||
}, nil, 0)
|
|
||||||
|
|
||||||
deactivate(nil)
|
|
||||||
expected := "domain0,domain2"
|
|
||||||
domains := []string{}
|
|
||||||
for _, item := range server.currentConfig.Domains {
|
|
||||||
if item.Disabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
domains = append(domains, item.Domain)
|
|
||||||
}
|
|
||||||
got := strings.Join(domains, ",")
|
|
||||||
if expected != got {
|
|
||||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
reactivate()
|
|
||||||
expected = "domain0,domain1,domain2"
|
|
||||||
domains = []string{}
|
|
||||||
for _, item := range server.currentConfig.Domains {
|
|
||||||
if item.Disabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
domains = append(domains, item.Domain)
|
|
||||||
}
|
|
||||||
got = strings.Join(domains, ",")
|
|
||||||
if expected != got {
|
|
||||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||||
|
skipUnlessAndroid(t)
|
||||||
wgIFace, err := createWgInterfaceWithBind(t)
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to initialize wg interface")
|
t.Fatal("failed to initialize wg interface")
|
||||||
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||||
|
skipUnlessAndroid(t)
|
||||||
wgIFace, err := createWgInterfaceWithBind(t)
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to initialize wg interface")
|
t.Fatal("failed to initialize wg interface")
|
||||||
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||||
|
skipUnlessAndroid(t)
|
||||||
wgIFace, err := createWgInterfaceWithBind(t)
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to initialize wg interface")
|
t.Fatal("failed to initialize wg interface")
|
||||||
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
|
||||||
|
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
|
||||||
|
// + androidHostManager). On non-android the desktop host manager replaces it
|
||||||
|
// during Initialize and the assertion stops making sense. Skipped here until we
|
||||||
|
// have an android CI runner.
|
||||||
|
func skipUnlessAndroid(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
@@ -1065,7 +1017,6 @@ type mockHandler struct {
|
|||||||
|
|
||||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
func (m *mockHandler) Stop() {}
|
func (m *mockHandler) Stop() {}
|
||||||
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
|
||||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||||
|
|
||||||
type mockService struct{}
|
type mockService struct{}
|
||||||
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
|||||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
|
||||||
|
// admin-defined nameserver groups targeting the same domain collapse into a
|
||||||
|
// single handler with each group preserved as a sequential inner list.
|
||||||
|
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||||
|
wgInterface := &mocWGIface{}
|
||||||
|
service := NewServiceViaMemory(wgInterface)
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
service: service,
|
||||||
|
localResolver: local.NewResolver(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: &noopHostConfigurator{},
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
},
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
},
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
|
||||||
|
assert.Equal(t, "example.com", muxUpdates[0].domain)
|
||||||
|
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
|
||||||
|
|
||||||
|
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||||
|
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
|
||||||
|
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
|
||||||
|
assert.Equal(t, upstreamRace{
|
||||||
|
netip.MustParseAddrPort("192.0.2.2:53"),
|
||||||
|
netip.MustParseAddrPort("192.0.2.3:53"),
|
||||||
|
}, handler.upstreamServers[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
|
||||||
|
// (overlay route selected-but-no-active-peer) is intentionally NOT an
|
||||||
|
// input to the evaluator anymore: the verdict drives the Enabled flag,
|
||||||
|
// which must always reflect what we actually observed. Gate-aware event
|
||||||
|
// suppression is tested separately in the projection test.
|
||||||
|
//
|
||||||
|
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
|
||||||
|
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
|
||||||
|
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
|
||||||
|
// fresh-working → Unhealthy; otherwise Undecided.
|
||||||
|
func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||||
|
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||||
|
|
||||||
|
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
|
||||||
|
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
|
||||||
|
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
|
||||||
|
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
|
||||||
|
okThenFail := UpstreamHealth{
|
||||||
|
LastOk: now.Add(-10 * time.Second),
|
||||||
|
LastFail: now.Add(-1 * time.Second),
|
||||||
|
LastErr: "timeout",
|
||||||
|
}
|
||||||
|
failThenOk := UpstreamHealth{
|
||||||
|
LastOk: now.Add(-1 * time.Second),
|
||||||
|
LastFail: now.Add(-10 * time.Second),
|
||||||
|
LastErr: "timeout",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
health map[netip.AddrPort]UpstreamHealth
|
||||||
|
servers []netip.AddrPort
|
||||||
|
wantVerdict nsGroupVerdict
|
||||||
|
wantErrSubst string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no record, undecided",
|
||||||
|
servers: []netip.AddrPort{a},
|
||||||
|
wantVerdict: nsVerdictUndecided,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fresh success, healthy",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
|
||||||
|
servers: []netip.AddrPort{a},
|
||||||
|
wantVerdict: nsVerdictHealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fresh failure, unhealthy",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
|
||||||
|
servers: []netip.AddrPort{a},
|
||||||
|
wantVerdict: nsVerdictUnhealthy,
|
||||||
|
wantErrSubst: "timeout",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only stale success, undecided",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
|
||||||
|
servers: []netip.AddrPort{a},
|
||||||
|
wantVerdict: nsVerdictUndecided,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only stale failure, undecided",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
|
||||||
|
servers: []netip.AddrPort{a},
|
||||||
|
wantVerdict: nsVerdictUndecided,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both fresh, fail newer, unhealthy",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
|
||||||
|
servers: []netip.AddrPort{a},
|
||||||
|
wantVerdict: nsVerdictUnhealthy,
|
||||||
|
wantErrSubst: "timeout",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both fresh, ok newer, healthy",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
|
||||||
|
servers: []netip.AddrPort{a},
|
||||||
|
wantVerdict: nsVerdictHealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two upstreams, one success wins",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{
|
||||||
|
a: recentFail,
|
||||||
|
b: recentOk,
|
||||||
|
},
|
||||||
|
servers: []netip.AddrPort{a, b},
|
||||||
|
wantVerdict: nsVerdictHealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two upstreams, one fail one unseen, unhealthy",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{
|
||||||
|
a: recentFail,
|
||||||
|
},
|
||||||
|
servers: []netip.AddrPort{a, b},
|
||||||
|
wantVerdict: nsVerdictUnhealthy,
|
||||||
|
wantErrSubst: "timeout",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two upstreams, all recent failures, unhealthy",
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{
|
||||||
|
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
|
||||||
|
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
|
||||||
|
},
|
||||||
|
servers: []netip.AddrPort{a, b},
|
||||||
|
wantVerdict: nsVerdictUnhealthy,
|
||||||
|
wantErrSubst: "SERVFAIL",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
|
||||||
|
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
|
||||||
|
if tc.wantErrSubst != "" {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tc.wantErrSubst)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||||
|
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||||
|
// without spinning up real handlers.
|
||||||
|
type healthStubHandler struct {
|
||||||
|
health map[netip.AddrPort]UpstreamHealth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
|
func (h *healthStubHandler) Stop() {}
|
||||||
|
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
|
||||||
|
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||||
|
return h.health
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_SteadyStateIsSilent guards against duplicate events:
|
||||||
|
// while a group stays Unhealthy tick after tick, only the first
|
||||||
|
// Unhealthy transition may emit. Same for staying Healthy.
|
||||||
|
func TestProjection_SteadyStateIsSilent(t *testing.T) {
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("unreachable", "first fail emits warning")
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.tick()
|
||||||
|
fx.expectNoEvent("staying unhealthy must not re-emit")
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("recovered", "recovery on transition")
|
||||||
|
|
||||||
|
fx.tick()
|
||||||
|
fx.tick()
|
||||||
|
fx.expectNoEvent("staying healthy must not re-emit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// projTestFixture is the common setup for the projection tests: a
|
||||||
|
// single-upstream group whose route classification the test can flip by
|
||||||
|
// assigning to selected/active. Callers drive failures/successes by
|
||||||
|
// mutating stub.health and calling refreshHealth.
|
||||||
|
type projTestFixture struct {
|
||||||
|
t *testing.T
|
||||||
|
recorder *peer.Status
|
||||||
|
events <-chan *proto.SystemEvent
|
||||||
|
server *DefaultServer
|
||||||
|
stub *healthStubHandler
|
||||||
|
group *nbdns.NameServerGroup
|
||||||
|
srv netip.AddrPort
|
||||||
|
selected route.HAMap
|
||||||
|
active route.HAMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||||
|
t.Helper()
|
||||||
|
recorder := peer.NewRecorder("mgm")
|
||||||
|
sub := recorder.SubscribeToEvents()
|
||||||
|
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||||
|
|
||||||
|
srv := netip.MustParseAddrPort("100.64.0.1:53")
|
||||||
|
fx := &projTestFixture{
|
||||||
|
t: t,
|
||||||
|
recorder: recorder,
|
||||||
|
events: sub.Events(),
|
||||||
|
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
|
||||||
|
srv: srv,
|
||||||
|
group: &nbdns.NameServerGroup{
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fx.server = &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
wgInterface: &mocWGIface{},
|
||||||
|
statusRecorder: recorder,
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||||
|
activeRoutes: func() route.HAMap { return fx.active },
|
||||||
|
warningDelayBase: defaultWarningDelayBase,
|
||||||
|
}
|
||||||
|
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||||
|
|
||||||
|
fx.server.mux.Lock()
|
||||||
|
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||||
|
fx.server.mux.Unlock()
|
||||||
|
return fx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *projTestFixture) setHealth(h UpstreamHealth) {
|
||||||
|
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *projTestFixture) tick() []peer.NSGroupState {
|
||||||
|
f.server.refreshHealth()
|
||||||
|
return f.recorder.GetDNSStates()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *projTestFixture) expectNoEvent(why string) {
|
||||||
|
f.t.Helper()
|
||||||
|
select {
|
||||||
|
case evt := <-f.events:
|
||||||
|
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
|
||||||
|
f.t.Helper()
|
||||||
|
select {
|
||||||
|
case evt := <-f.events:
|
||||||
|
assert.Contains(f.t, evt.Message, substr, why)
|
||||||
|
return evt
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
f.t.Fatalf("expected event (%s) with %q", why, substr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
|
||||||
|
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
|
||||||
|
|
||||||
|
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
|
||||||
|
// that is not inside any selected route (public DNS) fires the warning
|
||||||
|
// on the first Unhealthy tick, no grace period.
|
||||||
|
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
states := fx.tick()
|
||||||
|
require.Len(t, states, 1)
|
||||||
|
assert.False(t, states[0].Enabled)
|
||||||
|
fx.expectEvent("unreachable", "public DNS failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
|
||||||
|
// the upstream is inside a selected route AND the route has a Connected
|
||||||
|
// peer. Tunnel is up, failure is real, emit immediately.
|
||||||
|
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
fx.selected = overlayMapForTest
|
||||||
|
fx.active = overlayMapForTest
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
states := fx.tick()
|
||||||
|
require.Len(t, states, 1)
|
||||||
|
assert.False(t, states[0].Enabled)
|
||||||
|
fx.expectEvent("unreachable", "overlay + connected failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
|
||||||
|
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
|
||||||
|
// First tick: Unhealthy display, no warning. After the grace window
|
||||||
|
// elapses with no recovery, the warning fires.
|
||||||
|
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
|
||||||
|
grace := 50 * time.Millisecond
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
fx.server.warningDelayBase = grace
|
||||||
|
fx.selected = overlayMapForTest
|
||||||
|
// active stays nil: routed but not connected.
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
states := fx.tick()
|
||||||
|
require.Len(t, states, 1)
|
||||||
|
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
|
||||||
|
fx.expectNoEvent("first fail tick within grace window")
|
||||||
|
|
||||||
|
time.Sleep(grace + 10*time.Millisecond)
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("unreachable", "warning after grace window")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
|
||||||
|
// whose address is inside the WireGuard overlay range but is not
|
||||||
|
// covered by any selected route (peer-to-peer DNS without an explicit
|
||||||
|
// route). Until a peer reports Connected for that address, startup
|
||||||
|
// failures must be held just like the routed case.
|
||||||
|
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||||
|
recorder := peer.NewRecorder("mgm")
|
||||||
|
sub := recorder.SubscribeToEvents()
|
||||||
|
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||||
|
|
||||||
|
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
wgInterface: &mocWGIface{},
|
||||||
|
statusRecorder: recorder,
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
selectedRoutes: func() route.HAMap { return nil },
|
||||||
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
|
warningDelayBase: 50 * time.Millisecond,
|
||||||
|
}
|
||||||
|
group := &nbdns.NameServerGroup{
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
|
||||||
|
}
|
||||||
|
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||||
|
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||||
|
}}
|
||||||
|
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||||
|
|
||||||
|
server.mux.Lock()
|
||||||
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
|
server.mux.Unlock()
|
||||||
|
server.refreshHealth()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case evt := <-sub.Events():
|
||||||
|
t.Fatalf("unexpected event during grace window: %+v", evt)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
|
||||||
|
server.refreshHealth()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case evt := <-sub.Events():
|
||||||
|
assert.Contains(t, evt.Message, "unreachable")
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("expected warning after grace window")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_StopClearsHealthState verifies that Stop wipes the
|
||||||
|
// per-group projection state so a subsequent Start doesn't inherit
|
||||||
|
// sticky flags (notably everHealthy) that would bypass the grace
|
||||||
|
// window during the next peer handshake.
|
||||||
|
func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||||
|
wgIface := &mocWGIface{}
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
wgInterface: wgIface,
|
||||||
|
service: NewServiceViaMemory(wgIface),
|
||||||
|
hostManager: &noopHostConfigurator{},
|
||||||
|
extraDomains: map[domain.Domain]int{},
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
statusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
selectedRoutes: func() route.HAMap { return nil },
|
||||||
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
|
warningDelayBase: defaultWarningDelayBase,
|
||||||
|
currentConfigHash: ^uint64(0),
|
||||||
|
}
|
||||||
|
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
srv := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
group := &nbdns.NameServerGroup{
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||||
|
}
|
||||||
|
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||||
|
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||||
|
|
||||||
|
server.mux.Lock()
|
||||||
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
|
server.mux.Unlock()
|
||||||
|
server.refreshHealth()
|
||||||
|
|
||||||
|
server.healthProjectMu.Lock()
|
||||||
|
p, ok := server.nsGroupProj[generateGroupKey(group)]
|
||||||
|
server.healthProjectMu.Unlock()
|
||||||
|
require.True(t, ok, "projection state should exist after tick")
|
||||||
|
require.True(t, p.everHealthy, "tick with success must set everHealthy")
|
||||||
|
|
||||||
|
server.Stop()
|
||||||
|
|
||||||
|
server.healthProjectMu.Lock()
|
||||||
|
cleared := server.nsGroupProj == nil
|
||||||
|
server.healthProjectMu.Unlock()
|
||||||
|
assert.True(t, cleared, "Stop must clear nsGroupProj")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
|
||||||
|
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||||
|
// comes up and a query succeeds before the grace window elapses. No
|
||||||
|
// warning should ever have fired, and no recovery either.
|
||||||
|
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||||
|
fx.selected = overlayMapForTest
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectNoEvent("fail within grace, warning suppressed")
|
||||||
|
|
||||||
|
fx.active = overlayMapForTest
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
states := fx.tick()
|
||||||
|
require.Len(t, states, 1)
|
||||||
|
assert.True(t, states[0].Enabled)
|
||||||
|
fx.expectNoEvent("recovery without prior warning must not emit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
|
||||||
|
// whole design leans on: recovery events only appear when a warning
|
||||||
|
// event was actually emitted for the current streak. A Healthy verdict
|
||||||
|
// without a prior warning is silent, so the user never sees "recovered"
|
||||||
|
// out of thin air.
|
||||||
|
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
states := fx.tick()
|
||||||
|
require.Len(t, states, 1)
|
||||||
|
assert.True(t, states[0].Enabled)
|
||||||
|
fx.expectNoEvent("first healthy tick should not recover anything")
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("unreachable", "public fail emits immediately")
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("recovered", "recovery follows real warning")
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("unreachable", "second cycle warning")
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("recovered", "second cycle recovery")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
|
||||||
|
// has ever been Healthy, subsequent failures skip the grace window even
|
||||||
|
// if classification says "routed + not connected". The system has
|
||||||
|
// proved it can work, so any new failure is real.
|
||||||
|
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
|
||||||
|
fx.server.warningDelayBase = time.Hour
|
||||||
|
fx.selected = overlayMapForTest
|
||||||
|
fx.active = overlayMapForTest
|
||||||
|
|
||||||
|
// Establish "ever healthy".
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectNoEvent("first healthy tick")
|
||||||
|
|
||||||
|
// Peer drops. Query fails. Routed + not connected → normally grace,
|
||||||
|
// but everHealthy flag bypasses it.
|
||||||
|
fx.active = nil
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
|
||||||
|
// from the design discussion: once a group has been healthy, a brief
|
||||||
|
// reconnect that produces a failing tick will fire warning + recovery.
|
||||||
|
// This is by design: user-visible blips are accurate signal, not noise.
|
||||||
|
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
|
||||||
|
fx := newProjTestFixture(t)
|
||||||
|
fx.selected = overlayMapForTest
|
||||||
|
fx.active = overlayMapForTest
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
fx.tick()
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("unreachable", "blip warning")
|
||||||
|
|
||||||
|
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||||
|
fx.tick()
|
||||||
|
fx.expectEvent("recovered", "blip recovery")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
|
||||||
|
// rule: a group with at least one public upstream is in the "immediate"
|
||||||
|
// category regardless of the other upstreams' routing, because the
|
||||||
|
// public one has no peer-startup excuse. Prevents public-DNS failures
|
||||||
|
// from being hidden behind a routed sibling.
|
||||||
|
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||||
|
recorder := peer.NewRecorder("mgm")
|
||||||
|
sub := recorder.SubscribeToEvents()
|
||||||
|
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||||
|
events := sub.Events()
|
||||||
|
|
||||||
|
public := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
overlay := netip.MustParseAddrPort("100.64.0.1:53")
|
||||||
|
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusRecorder: recorder,
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||||
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
|
warningDelayBase: time.Hour,
|
||||||
|
}
|
||||||
|
group := &nbdns.NameServerGroup{
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
|
||||||
|
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
stub := &healthStubHandler{
|
||||||
|
health: map[netip.AddrPort]UpstreamHealth{
|
||||||
|
public: {LastFail: time.Now(), LastErr: "servfail"},
|
||||||
|
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||||
|
|
||||||
|
server.mux.Lock()
|
||||||
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
|
server.mux.Unlock()
|
||||||
|
server.refreshHealth()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case evt := <-events:
|
||||||
|
assert.Contains(t, evt.Message, "unreachable")
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("expected immediate warning because group contains a public upstream")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSLoopPrevention(t *testing.T) {
|
func TestDNSLoopPrevention(t *testing.T) {
|
||||||
wgInterface := &mocWGIface{}
|
wgInterface := &mocWGIface{}
|
||||||
service := NewServiceViaMemory(wgInterface)
|
service := NewServiceViaMemory(wgInterface)
|
||||||
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
|
|||||||
|
|
||||||
if tt.expectedHandlers > 0 {
|
if tt.expectedHandlers > 0 {
|
||||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||||
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
|
flat := handler.flatUpstreams()
|
||||||
|
assert.Len(t, flat, len(tt.expectedServers))
|
||||||
|
|
||||||
if tt.shouldFilterOwnIP {
|
if tt.shouldFilterOwnIP {
|
||||||
for _, upstream := range handler.upstreamServers {
|
for _, upstream := range flat {
|
||||||
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expected := range tt.expectedServers {
|
for _, expected := range tt.expectedServers {
|
||||||
found := false
|
found := false
|
||||||
for _, upstream := range handler.upstreamServers {
|
for _, upstream := range flat {
|
||||||
if upstream.Addr() == expected {
|
if upstream.Addr() == expected {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
@@ -40,10 +41,17 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
ifaceName string
|
ifaceName string
|
||||||
|
wgIndex int
|
||||||
|
origNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
|
||||||
|
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
|
||||||
|
)
|
||||||
|
|
||||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||||
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
|
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
|
||||||
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
|
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
|
||||||
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
|
|||||||
|
|
||||||
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
||||||
|
|
||||||
return &systemdDbusConfigurator{
|
c := &systemdDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
ifaceName: wgInterface,
|
ifaceName: wgInterface,
|
||||||
}, nil
|
wgIndex: iface.Index,
|
||||||
|
}
|
||||||
|
|
||||||
|
origNameservers, err := c.captureOriginalNameservers()
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
|
||||||
|
case len(origNameservers) == 0:
|
||||||
|
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
|
||||||
|
default:
|
||||||
|
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
|
||||||
|
}
|
||||||
|
c.origNameservers = origNameservers
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
|
||||||
|
// every default-route link except our own WG link. Non-default-route links
|
||||||
|
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
|
||||||
|
// actually serve host queries.
|
||||||
|
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list interfaces: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[netip.Addr]struct{})
|
||||||
|
var out []netip.Addr
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if !s.isCandidateLink(iface) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
linkPath, err := getSystemdLinkPath(iface.Index)
|
||||||
|
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, addr := range readSystemdLinkDNS(linkPath) {
|
||||||
|
addr = normalizeSystemdAddr(addr, iface.Name)
|
||||||
|
if !addr.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, dup := seen[addr]; dup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[addr] = struct{}{}
|
||||||
|
out = append(out, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
|
||||||
|
if iface.Index == s.wgIndex {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
|
||||||
|
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
|
||||||
|
// Returns the zero Addr to signal "skip this entry".
|
||||||
|
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
|
||||||
|
addr = addr.Unmap()
|
||||||
|
if !addr.IsValid() || addr.IsUnspecified() {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
if addr.IsLinkLocalUnicast() {
|
||||||
|
return addr.WithZone(ifaceName)
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
|
||||||
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("dbus resolve1: %w", err)
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
var p string
|
||||||
|
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return dbus.ObjectPath(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
|
||||||
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
b, ok := v.Value().(bool)
|
||||||
|
return ok && b
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
|
||||||
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer closeConn()
|
||||||
|
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
entries, ok := v.Value().([][]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var out []netip.Addr
|
||||||
|
for _, entry := range entries {
|
||||||
|
if len(entry) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
raw, ok := entry[1].([]byte)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addr, ok := netip.AddrFromSlice(raw)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, addr)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
return slices.Clone(s.origNameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
||||||
|
|||||||
@@ -1,3 +1,32 @@
|
|||||||
|
// Package dns implements the client-side DNS stack: listener/service on the
|
||||||
|
// peer's tunnel address, handler chain that routes questions by domain and
|
||||||
|
// priority, and upstream resolvers that forward what remains to configured
|
||||||
|
// nameservers.
|
||||||
|
//
|
||||||
|
// # Upstream resolution and the race model
|
||||||
|
//
|
||||||
|
// When two or more nameserver groups target the same domain, DefaultServer
|
||||||
|
// merges them into one upstream handler whose state is:
|
||||||
|
//
|
||||||
|
// upstreamResolverBase
|
||||||
|
// └── upstreamServers []upstreamRace // one entry per source NS group
|
||||||
|
// └── []netip.AddrPort // primary, fallback, ...
|
||||||
|
//
|
||||||
|
// Each source nameserver group contributes one upstreamRace. Within a race
|
||||||
|
// upstreams are tried in order: the next is used only on failure (timeout,
|
||||||
|
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
|
||||||
|
// the walk. When more than one race exists, ServeDNS fans out one
|
||||||
|
// goroutine per race and returns the first valid answer, cancelling the
|
||||||
|
// rest. A handler with a single race skips the fan-out.
|
||||||
|
//
|
||||||
|
// # Health projection
|
||||||
|
//
|
||||||
|
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
|
||||||
|
// periodically merges these snapshots across handlers and projects them
|
||||||
|
// into peer.NSGroupState. There is no active probing: a group is marked
|
||||||
|
// unhealthy only when every seen upstream has a recent failure and none
|
||||||
|
// has a recent success. Healthy→unhealthy fires a single
|
||||||
|
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -11,11 +40,8 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
@@ -25,7 +51,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var currentMTU uint16 = iface.DefaultMTU
|
var currentMTU uint16 = iface.DefaultMTU
|
||||||
@@ -67,15 +94,17 @@ const (
|
|||||||
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
||||||
ClientTimeout = 5 * time.Second
|
ClientTimeout = 5 * time.Second
|
||||||
|
|
||||||
reactivatePeriod = 30 * time.Second
|
|
||||||
probeTimeout = 2 * time.Second
|
|
||||||
|
|
||||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||||
// payload from the tunnel MTU.
|
// payload from the tunnel MTU.
|
||||||
ipUDPHeaderSize = 60 + 8
|
ipUDPHeaderSize = 60 + 8
|
||||||
)
|
|
||||||
|
|
||||||
const testRecord = "com."
|
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
|
||||||
|
// within one race, so a slow primary can't eat the whole race budget.
|
||||||
|
raceMaxTotalTimeout = 5 * time.Second
|
||||||
|
// raceMinPerUpstreamTimeout is the floor applied when dividing
|
||||||
|
// raceMaxTotalTimeout across upstreams within a race.
|
||||||
|
raceMinPerUpstreamTimeout = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
protoUDP = "udp"
|
protoUDP = "udp"
|
||||||
@@ -84,6 +113,69 @@ const (
|
|||||||
|
|
||||||
type dnsProtocolKey struct{}
|
type dnsProtocolKey struct{}
|
||||||
|
|
||||||
|
type upstreamProtocolKey struct{}
|
||||||
|
|
||||||
|
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||||
|
// Stored as a pointer in context so the exchange function can set it.
|
||||||
|
type upstreamProtocolResult struct {
|
||||||
|
protocol string
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamClient interface {
|
||||||
|
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpstreamResolver interface {
|
||||||
|
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||||
|
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamRace is an ordered list of upstreams derived from one configured
|
||||||
|
// nameserver group. Order matters: the first upstream is tried first, the
|
||||||
|
// second only on failure, and so on. Multiple upstreamRace values coexist
|
||||||
|
// inside one resolver when overlapping nameserver groups target the same
|
||||||
|
// domain; those races run in parallel and the first valid answer wins.
|
||||||
|
type upstreamRace []netip.AddrPort
|
||||||
|
|
||||||
|
// UpstreamHealth is the last query-path outcome for a single upstream,
|
||||||
|
// consumed by nameserver-group status projection.
|
||||||
|
type UpstreamHealth struct {
|
||||||
|
LastOk time.Time
|
||||||
|
LastFail time.Time
|
||||||
|
LastErr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamResolverBase struct {
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
upstreamClient upstreamClient
|
||||||
|
upstreamServers []upstreamRace
|
||||||
|
domain domain.Domain
|
||||||
|
upstreamTimeout time.Duration
|
||||||
|
|
||||||
|
healthMu sync.RWMutex
|
||||||
|
health map[netip.AddrPort]*UpstreamHealth
|
||||||
|
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
// selectedRoutes returns the current set of client routes the admin
|
||||||
|
// has enabled. Called lazily from the query hot path when an upstream
|
||||||
|
// might need a tunnel-bound client (iOS) and from health projection.
|
||||||
|
selectedRoutes func() route.HAMap
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamFailure struct {
|
||||||
|
upstream netip.AddrPort
|
||||||
|
reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
type raceResult struct {
|
||||||
|
msg *dns.Msg
|
||||||
|
upstream netip.AddrPort
|
||||||
|
protocol string
|
||||||
|
ede string
|
||||||
|
failures []upstreamFailure
|
||||||
|
}
|
||||||
|
|
||||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||||
@@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type upstreamProtocolKey struct{}
|
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
|
||||||
|
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
|
||||||
// Stored as a pointer in context so the exchange function can set it.
|
|
||||||
type upstreamProtocolResult struct {
|
|
||||||
protocol string
|
|
||||||
}
|
|
||||||
|
|
||||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
|
||||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
|
||||||
r := &upstreamProtocolResult{}
|
r := &upstreamProtocolResult{}
|
||||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||||
}
|
}
|
||||||
@@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type upstreamClient interface {
|
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
|
||||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type UpstreamResolver interface {
|
|
||||||
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
|
||||||
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type upstreamResolverBase struct {
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
upstreamClient upstreamClient
|
|
||||||
upstreamServers []netip.AddrPort
|
|
||||||
domain string
|
|
||||||
disabled bool
|
|
||||||
successCount atomic.Int32
|
|
||||||
mutex sync.Mutex
|
|
||||||
reactivatePeriod time.Duration
|
|
||||||
upstreamTimeout time.Duration
|
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
deactivate func(error)
|
|
||||||
reactivate func()
|
|
||||||
statusRecorder *peer.Status
|
|
||||||
routeMatch func(netip.Addr) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type upstreamFailure struct {
|
|
||||||
upstream netip.AddrPort
|
|
||||||
reason string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
domain: domain,
|
domain: d,
|
||||||
upstreamTimeout: UpstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
statusRecorder: statusRecorder,
|
||||||
statusRecorder: statusRecorder,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns a string representation of the upstream resolver
|
// String returns a string representation of the upstream resolver
|
||||||
func (u *upstreamResolverBase) String() string {
|
func (u *upstreamResolverBase) String() string {
|
||||||
return fmt.Sprintf("Upstream %s", u.upstreamServers)
|
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID. Race groupings and within-race
|
||||||
|
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
|
||||||
|
// the same servers but with different semantics (serial fallback vs
|
||||||
|
// parallel race), so their handlers must not collide.
|
||||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||||
servers := slices.Clone(u.upstreamServers)
|
|
||||||
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
|
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
hash.Write([]byte(u.domain + ":"))
|
hash.Write([]byte(u.domain.PunycodeString() + ":"))
|
||||||
for _, s := range servers {
|
for _, race := range u.upstreamServers {
|
||||||
hash.Write([]byte(s.String()))
|
hash.Write([]byte("["))
|
||||||
hash.Write([]byte("|"))
|
for _, s := range race {
|
||||||
|
hash.Write([]byte(s.String()))
|
||||||
|
hash.Write([]byte("|"))
|
||||||
|
}
|
||||||
|
hash.Write([]byte("]"))
|
||||||
}
|
}
|
||||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
}
|
}
|
||||||
@@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) Stop() {
|
func (u *upstreamResolverBase) Stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
|
||||||
u.cancel()
|
u.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
u.mutex.Lock()
|
// flatUpstreams is for logging and ID hashing only, not for dispatch.
|
||||||
u.wg.Wait()
|
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
|
||||||
u.mutex.Unlock()
|
var out []netip.AddrPort
|
||||||
|
for _, g := range u.upstreamServers {
|
||||||
|
out = append(out, g...)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSelectedRoutes swaps the accessor used to classify overlay-routed
|
||||||
|
// upstreams. Called when route sources are wired after the handler was
|
||||||
|
// built (permanent / iOS constructors).
|
||||||
|
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
|
||||||
|
u.selectedRoutes = selected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
|
||||||
|
if len(servers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
@@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||||
timeout := u.upstreamTimeout
|
groups := u.upstreamServers
|
||||||
if len(u.upstreamServers) > 1 {
|
switch len(groups) {
|
||||||
maxTotal := 5 * time.Second
|
case 0:
|
||||||
minPerUpstream := 2 * time.Second
|
return false, nil
|
||||||
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
|
case 1:
|
||||||
if scaledTimeout > minPerUpstream {
|
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
|
||||||
timeout = scaledTimeout
|
default:
|
||||||
} else {
|
return u.raceAll(ctx, w, r, groups, logger)
|
||||||
timeout = minPerUpstream
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||||
|
res := u.tryRace(ctx, r, group)
|
||||||
|
if res.msg == nil {
|
||||||
|
return false, res.failures
|
||||||
|
}
|
||||||
|
if res.ede != "" {
|
||||||
|
resutil.SetMeta(w, "ede", res.ede)
|
||||||
|
}
|
||||||
|
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||||
|
return true, res.failures
|
||||||
|
}
|
||||||
|
|
||||||
|
// raceAll runs one worker per group in parallel, taking the first valid
|
||||||
|
// answer and cancelling the rest.
|
||||||
|
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||||
|
raceCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Buffer sized to len(groups) so workers never block on send, even
|
||||||
|
// after the coordinator has returned.
|
||||||
|
results := make(chan raceResult, len(groups))
|
||||||
|
for _, g := range groups {
|
||||||
|
// tryRace clones the request per attempt, so workers never share
|
||||||
|
// a *dns.Msg and concurrent EDNS0 mutations can't race.
|
||||||
|
go func(g upstreamRace) {
|
||||||
|
results <- u.tryRace(raceCtx, r, g)
|
||||||
|
}(g)
|
||||||
}
|
}
|
||||||
|
|
||||||
var failures []upstreamFailure
|
var failures []upstreamFailure
|
||||||
for _, upstream := range u.upstreamServers {
|
for range groups {
|
||||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
select {
|
||||||
failures = append(failures, *failure)
|
case res := <-results:
|
||||||
} else {
|
failures = append(failures, res.failures...)
|
||||||
return true, failures
|
if res.msg != nil {
|
||||||
|
if res.ede != "" {
|
||||||
|
resutil.SetMeta(w, "ede", res.ede)
|
||||||
|
}
|
||||||
|
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||||
|
return true, failures
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, failures
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false, failures
|
return false, failures
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
|
||||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
timeout := u.upstreamTimeout
|
||||||
var rm *dns.Msg
|
if len(group) > 1 {
|
||||||
var t time.Duration
|
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
|
||||||
var err error
|
// still honor raceMinPerUpstreamTimeout as a floor for correctness
|
||||||
|
// on slow links, but the outer context ensures the combined walk
|
||||||
|
// cannot exceed the cap regardless of group size.
|
||||||
|
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
var failures []upstreamFailure
|
||||||
|
for _, upstream := range group {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return raceResult{failures: failures}
|
||||||
|
}
|
||||||
|
// Clone the request per attempt: the exchange path mutates EDNS0
|
||||||
|
// options in-place, so reusing the same *dns.Msg across sequential
|
||||||
|
// upstreams would carry those mutations (e.g. a reduced UDP size)
|
||||||
|
// into the next attempt.
|
||||||
|
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
|
||||||
|
if failure != nil {
|
||||||
|
failures = append(failures, *failure)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
res.failures = failures
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
return raceResult{failures: failures}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
|
||||||
|
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
|
||||||
|
|
||||||
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
||||||
// (RFC 8914) in failure responses; we use those to short-circuit
|
// (RFC 8914) in failure responses; we use those to short-circuit
|
||||||
// failover for definitive answers like DNSSEC validation failures.
|
// failover for definitive answers like DNSSEC validation failures.
|
||||||
// Operate on a copy so the inbound request is unchanged: a client that
|
// The caller already passed a per-attempt copy, so we can mutate r
|
||||||
// did not advertise EDNS0 must not see an OPT in the response.
|
// directly; hadEdns reflects the original client request's state and
|
||||||
|
// controls whether we strip the OPT from the response.
|
||||||
hadEdns := r.IsEdns0() != nil
|
hadEdns := r.IsEdns0() != nil
|
||||||
reqUp := r
|
|
||||||
if !hadEdns {
|
if !hadEdns {
|
||||||
reqUp = r.Copy()
|
r.SetEdns0(upstreamUDPSize(), false)
|
||||||
reqUp.SetEdns0(upstreamUDPSize(), false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var startTime time.Time
|
startTime := time.Now()
|
||||||
var upstreamProto *upstreamProtocolResult
|
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||||
func() {
|
|
||||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
|
||||||
defer cancel()
|
|
||||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
|
||||||
startTime = time.Now()
|
|
||||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return u.handleUpstreamError(err, upstream, startTime)
|
// A parent cancellation (e.g., another race won and the coordinator
|
||||||
|
// cancelled the losers) is not an upstream failure. Check both the
|
||||||
|
// error chain and the parent context: a transport may surface the
|
||||||
|
// cancellation as a read/deadline error rather than context.Canceled.
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
|
||||||
|
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
|
||||||
|
}
|
||||||
|
failure := u.handleUpstreamError(err, upstream, startTime)
|
||||||
|
u.markUpstreamFail(upstream, failure.reason)
|
||||||
|
return raceResult{}, failure
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil || !rm.Response {
|
if rm == nil || !rm.Response {
|
||||||
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
u.markUpstreamFail(upstream, "no response")
|
||||||
|
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := ""
|
||||||
|
if upstreamProto != nil {
|
||||||
|
proto = upstreamProto.protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||||
if code, ok := nonRetryableEDE(rm); ok {
|
if code, ok := nonRetryableEDE(rm); ok {
|
||||||
resutil.SetMeta(w, "ede", edeName(code))
|
|
||||||
if !hadEdns {
|
if !hadEdns {
|
||||||
stripOPT(rm)
|
stripOPT(rm)
|
||||||
}
|
}
|
||||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
u.markUpstreamOk(upstream)
|
||||||
return nil
|
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||||
}
|
}
|
||||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
reason := dns.RcodeToString[rm.Rcode]
|
||||||
|
u.markUpstreamFail(upstream, reason)
|
||||||
|
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hadEdns {
|
if !hadEdns {
|
||||||
stripOPT(rm)
|
stripOPT(rm)
|
||||||
}
|
}
|
||||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
|
||||||
return nil
|
u.markUpstreamOk(upstream)
|
||||||
|
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthEntry returns the mutable health record for addr, lazily creating
|
||||||
|
// the map and the entry. Caller must hold u.healthMu.
|
||||||
|
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
|
||||||
|
if u.health == nil {
|
||||||
|
u.health = make(map[netip.AddrPort]*UpstreamHealth)
|
||||||
|
}
|
||||||
|
h := u.health[addr]
|
||||||
|
if h == nil {
|
||||||
|
h = &UpstreamHealth{}
|
||||||
|
u.health[addr] = h
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
|
||||||
|
u.healthMu.Lock()
|
||||||
|
defer u.healthMu.Unlock()
|
||||||
|
h := u.healthEntry(addr)
|
||||||
|
h.LastOk = time.Now()
|
||||||
|
h.LastFail = time.Time{}
|
||||||
|
h.LastErr = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
|
||||||
|
u.healthMu.Lock()
|
||||||
|
defer u.healthMu.Unlock()
|
||||||
|
h := u.healthEntry(addr)
|
||||||
|
h.LastFail = time.Now()
|
||||||
|
h.LastErr = reason
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
|
||||||
|
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||||
|
u.healthMu.RLock()
|
||||||
|
defer u.healthMu.RUnlock()
|
||||||
|
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
|
||||||
|
for k, v := range u.health {
|
||||||
|
out[k] = *v
|
||||||
|
}
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
||||||
@@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
|||||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
||||||
u.successCount.Add(1)
|
if u.statusRecorder == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
||||||
|
if peerInfo == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
|
||||||
resutil.SetMeta(w, "upstream", upstream.String())
|
resutil.SetMeta(w, "upstream", upstream.String())
|
||||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
if proto != "" {
|
||||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
resutil.SetMeta(w, "upstream_protocol", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear Zero bit from external responses to prevent upstream servers from
|
// Clear Zero bit from external responses to prevent upstream servers from
|
||||||
@@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
|||||||
|
|
||||||
if err := w.WriteMsg(rm); err != nil {
|
if err := w.WriteMsg(rm); err != nil {
|
||||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||||
totalUpstreams := len(u.upstreamServers)
|
totalUpstreams := len(u.flatUpstreams())
|
||||||
failedCount := len(failures)
|
failedCount := len(failures)
|
||||||
failureSummary := formatFailures(failures)
|
failureSummary := formatFailures(failures)
|
||||||
|
|
||||||
@@ -434,119 +633,6 @@ func edeName(code uint16) string {
|
|||||||
return fmt.Sprintf("EDE %d", code)
|
return fmt.Sprintf("EDE %d", code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProbeAvailability tests all upstream servers simultaneously and
|
|
||||||
// disables the resolver if none work
|
|
||||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
|
||||||
u.mutex.Lock()
|
|
||||||
defer u.mutex.Unlock()
|
|
||||||
|
|
||||||
// avoid probe if upstreams could resolve at least one query
|
|
||||||
if u.successCount.Load() > 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var success bool
|
|
||||||
var mu sync.Mutex
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
var errs *multierror.Error
|
|
||||||
for _, upstream := range u.upstreamServers {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(upstream netip.AddrPort) {
|
|
||||||
defer wg.Done()
|
|
||||||
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
|
||||||
if err != nil {
|
|
||||||
mu.Lock()
|
|
||||||
errs = multierror.Append(errs, err)
|
|
||||||
mu.Unlock()
|
|
||||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
success = true
|
|
||||||
mu.Unlock()
|
|
||||||
}(upstream)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-u.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// didn't find a working upstream server, let's disable and try later
|
|
||||||
if !success {
|
|
||||||
u.disable(errs.ErrorOrNil())
|
|
||||||
|
|
||||||
if u.statusRecorder == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
u.statusRecorder.PublishEvent(
|
|
||||||
proto.SystemEvent_WARNING,
|
|
||||||
proto.SystemEvent_DNS,
|
|
||||||
"All upstream servers failed (probe failed)",
|
|
||||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
|
||||||
map[string]string{"upstreams": u.upstreamServersString()},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
|
|
||||||
func (u *upstreamResolverBase) waitUntilResponse() {
|
|
||||||
exponentialBackOff := &backoff.ExponentialBackOff{
|
|
||||||
InitialInterval: 500 * time.Millisecond,
|
|
||||||
RandomizationFactor: 0.5,
|
|
||||||
Multiplier: 1.1,
|
|
||||||
MaxInterval: u.reactivatePeriod,
|
|
||||||
MaxElapsedTime: 0,
|
|
||||||
Stop: backoff.Stop,
|
|
||||||
Clock: backoff.SystemClock,
|
|
||||||
}
|
|
||||||
|
|
||||||
operation := func() error {
|
|
||||||
select {
|
|
||||||
case <-u.ctx.Done():
|
|
||||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, upstream := range u.upstreamServers {
|
|
||||||
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
|
||||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
|
||||||
} else {
|
|
||||||
// at least one upstream server is available, stop probing
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
|
|
||||||
return fmt.Errorf("upstream check call error")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, context.Canceled) {
|
|
||||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
|
||||||
} else {
|
|
||||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
|
||||||
u.successCount.Add(1)
|
|
||||||
u.reactivate()
|
|
||||||
u.mutex.Lock()
|
|
||||||
u.disabled = false
|
|
||||||
u.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// isTimeout returns true if the given error is a network timeout error.
|
// isTimeout returns true if the given error is a network timeout error.
|
||||||
//
|
//
|
||||||
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
|
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
|
||||||
@@ -558,45 +644,6 @@ func isTimeout(err error) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) disable(err error) {
|
|
||||||
if u.disabled {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
|
||||||
u.successCount.Store(0)
|
|
||||||
u.deactivate(err)
|
|
||||||
u.disabled = true
|
|
||||||
u.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer u.wg.Done()
|
|
||||||
u.waitUntilResponse()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
|
||||||
var servers []string
|
|
||||||
for _, server := range u.upstreamServers {
|
|
||||||
servers = append(servers, server.String())
|
|
||||||
}
|
|
||||||
return strings.Join(servers, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
|
||||||
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if externalCtx != nil {
|
|
||||||
stop2 := context.AfterFunc(externalCtx, cancel)
|
|
||||||
defer stop2()
|
|
||||||
}
|
|
||||||
|
|
||||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
|
||||||
|
|
||||||
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||||
func clientUDPMaxSize(r *dns.Msg) int {
|
func clientUDPMaxSize(r *dns.Msg) int {
|
||||||
if opt := r.IsEdns0(); opt != nil {
|
if opt := r.IsEdns0(); opt != nil {
|
||||||
@@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
|
|||||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
|
||||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||||
// If the request came in over TCP, go straight to TCP upstream.
|
// If the request came in over TCP, go straight to TCP upstream.
|
||||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||||
tcpClient := *client
|
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||||
tcpClient.Net = protoTCP
|
|
||||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||||
}
|
}
|
||||||
@@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
opt.SetUDPSize(maxUDPPayload)
|
opt.SetUDPSize(maxUDPPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
rm, t, err := client.ExchangeContext(ctx, r, upstream)
|
||||||
rm *dns.Msg
|
|
||||||
t time.Duration
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
if ctx == nil {
|
|
||||||
rm, t, err = client.Exchange(r, upstream)
|
|
||||||
} else {
|
|
||||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, t, fmt.Errorf("with udp: %w", err)
|
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||||
}
|
}
|
||||||
@@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
// data than the client's buffer, we could truncate locally and skip
|
// data than the client's buffer, we could truncate locally and skip
|
||||||
// the TCP retry.
|
// the TCP retry.
|
||||||
|
|
||||||
tcpClient := *client
|
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||||
tcpClient.Net = protoTCP
|
|
||||||
|
|
||||||
if ctx == nil {
|
|
||||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
|
||||||
} else {
|
|
||||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||||
}
|
}
|
||||||
@@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
|
||||||
|
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
|
||||||
|
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
|
||||||
|
// so net.Dialer doesn't reject the TCP dial with "mismatched local
|
||||||
|
// address type".
|
||||||
|
func toTCPClient(c *dns.Client) *dns.Client {
|
||||||
|
tcp := *c
|
||||||
|
tcp.Net = protoTCP
|
||||||
|
if tcp.Dialer == nil {
|
||||||
|
return &tcp
|
||||||
|
}
|
||||||
|
d := *tcp.Dialer
|
||||||
|
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
|
||||||
|
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
|
||||||
|
}
|
||||||
|
tcp.Dialer = &d
|
||||||
|
return &tcp
|
||||||
|
}
|
||||||
|
|
||||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||||
@@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
|
|||||||
return bestMatch
|
return bestMatch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
// haMapRouteCount returns the total number of routes across all HA
|
||||||
if u.statusRecorder == nil {
|
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
|
||||||
return ""
|
// routes per key, so len(hm) is the number of HA groups, not routes.
|
||||||
|
func haMapRouteCount(hm route.HAMap) int {
|
||||||
|
total := 0
|
||||||
|
for _, routes := range hm {
|
||||||
|
total += len(routes)
|
||||||
}
|
}
|
||||||
|
return total
|
||||||
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
}
|
||||||
if peerInfo == nil {
|
|
||||||
return ""
|
// haMapContains checks whether ip is covered by any concrete prefix in
|
||||||
}
|
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
|
||||||
|
// routes carry a placeholder Network that can't be prefix-checked, so we
|
||||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
// can't know at this point whether ip is reached through one. Callers
|
||||||
|
// decide how to interpret the unknown: health projection treats it as
|
||||||
|
// "possibly routed" to avoid emitting false-positive warnings during
|
||||||
|
// startup, while iOS dial selection requires a concrete match before
|
||||||
|
// binding to the tunnel.
|
||||||
|
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
|
||||||
|
for _, routes := range hm {
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
haveDynamic = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if r.Network.Contains(ip) {
|
||||||
|
return true, haveDynamic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, haveDynamic
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
@@ -26,9 +27,9 @@ func newUpstreamResolver(
|
|||||||
_ WGIface,
|
_ WGIface,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
domain string,
|
d domain.Domain,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||||
c := &upstreamResolver{
|
c := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
hostsDNSHolder: hostsDNSHolder,
|
hostsDNSHolder: hostsDNSHolder,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
@@ -24,9 +25,9 @@ func newUpstreamResolver(
|
|||||||
wgIface WGIface,
|
wgIface WGIface,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
d domain.Domain,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||||
nonIOS := &upstreamResolver{
|
nonIOS := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
nsNet: wgIface.GetNet(),
|
nsNet: wgIface.GetNet(),
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
@@ -27,9 +28,9 @@ func newUpstreamResolver(
|
|||||||
wgIface WGIface,
|
wgIface WGIface,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
d domain.Domain,
|
||||||
) (*upstreamResolverIOS, error) {
|
) (*upstreamResolverIOS, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
upstreamIP = upstreamIP.Unmap()
|
upstreamIP = upstreamIP.Unmap()
|
||||||
}
|
}
|
||||||
addr := u.wgIface.Address()
|
addr := u.wgIface.Address()
|
||||||
|
var routed bool
|
||||||
|
if u.selectedRoutes != nil {
|
||||||
|
// Only a concrete prefix match binds to the tunnel: dialing
|
||||||
|
// through a private client for an upstream we can't prove is
|
||||||
|
// routed would break public resolvers.
|
||||||
|
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
|
||||||
|
}
|
||||||
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
||||||
addr.IPv6Net.Contains(upstreamIP) ||
|
addr.IPv6Net.Contains(upstreamIP) ||
|
||||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
routed
|
||||||
if needsPrivate {
|
if needsPrivate {
|
||||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||||
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
||||||
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||||
return ExchangeWithFallback(nil, client, r, upstream)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resolver.upstreamServers = servers
|
resolver.addRace(servers)
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
cancel()
|
cancel()
|
||||||
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockUpstreamResolver struct {
|
|
||||||
r *dns.Msg
|
|
||||||
rtt time.Duration
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// exchange mock implementation of exchange from upstreamResolver
|
|
||||||
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
|
||||||
return c.r, c.rtt, c.err
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockUpstreamResponse struct {
|
type mockUpstreamResponse struct {
|
||||||
msg *dns.Msg
|
msg *dns.Msg
|
||||||
err error
|
err error
|
||||||
|
delay time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockUpstreamResolverPerServer struct {
|
type mockUpstreamResolverPerServer struct {
|
||||||
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
|
|||||||
rtt time.Duration
|
rtt time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||||
if r, ok := c.responses[upstream]; ok {
|
r, ok := c.responses[upstream]
|
||||||
return r.msg, c.rtt, r.err
|
if !ok {
|
||||||
|
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||||
}
|
}
|
||||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
if r.delay > 0 {
|
||||||
}
|
select {
|
||||||
|
case <-time.After(r.delay):
|
||||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
case <-ctx.Done():
|
||||||
mockClient := &mockUpstreamResolver{
|
return nil, c.rtt, ctx.Err()
|
||||||
err: dns.ErrTime,
|
}
|
||||||
r: new(dns.Msg),
|
|
||||||
rtt: time.Millisecond,
|
|
||||||
}
|
|
||||||
|
|
||||||
resolver := &upstreamResolverBase{
|
|
||||||
ctx: context.TODO(),
|
|
||||||
upstreamClient: mockClient,
|
|
||||||
upstreamTimeout: UpstreamTimeout,
|
|
||||||
reactivatePeriod: time.Microsecond * 100,
|
|
||||||
}
|
|
||||||
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
|
|
||||||
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
|
|
||||||
|
|
||||||
failed := false
|
|
||||||
resolver.deactivate = func(error) {
|
|
||||||
failed = true
|
|
||||||
// After deactivation, make the mock client work again
|
|
||||||
mockClient.err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
reactivated := false
|
|
||||||
resolver.reactivate = func() {
|
|
||||||
reactivated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
resolver.ProbeAvailability(context.TODO())
|
|
||||||
|
|
||||||
if !failed {
|
|
||||||
t.Errorf("expected that resolving was deactivated")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !resolver.disabled {
|
|
||||||
t.Errorf("resolver should be Disabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Millisecond * 200)
|
|
||||||
|
|
||||||
if !reactivated {
|
|
||||||
t.Errorf("expected that resolving was reactivated")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resolver.disabled {
|
|
||||||
t.Errorf("should be enabled")
|
|
||||||
}
|
}
|
||||||
|
return r.msg, c.rtt, r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpstreamResolver_Failover(t *testing.T) {
|
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||||
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
|
|||||||
resolver := &upstreamResolverBase{
|
resolver := &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
upstreamClient: trackingClient,
|
upstreamClient: trackingClient,
|
||||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
|
||||||
upstreamTimeout: UpstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
}
|
}
|
||||||
|
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
|
||||||
|
|
||||||
var responseMSG *dns.Msg
|
var responseMSG *dns.Msg
|
||||||
responseWriter := &test.MockResponseWriter{
|
responseWriter := &test.MockResponseWriter{
|
||||||
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
|||||||
resolver := &upstreamResolverBase{
|
resolver := &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
upstreamClient: mockClient,
|
upstreamClient: mockClient,
|
||||||
upstreamServers: []netip.AddrPort{upstream},
|
|
||||||
upstreamTimeout: UpstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
}
|
}
|
||||||
|
resolver.addRace([]netip.AddrPort{upstream})
|
||||||
|
|
||||||
var responseMSG *dns.Msg
|
var responseMSG *dns.Msg
|
||||||
responseWriter := &test.MockResponseWriter{
|
responseWriter := &test.MockResponseWriter{
|
||||||
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
|||||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
|
||||||
|
// configured for the same domain, with one broken group. The merge+race
|
||||||
|
// path should answer as fast as the working group and not pay the timeout
|
||||||
|
// of the broken one on every query.
|
||||||
|
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
|
||||||
|
broken := netip.MustParseAddrPort("192.0.2.1:53")
|
||||||
|
working := netip.MustParseAddrPort("192.0.2.2:53")
|
||||||
|
successAnswer := "192.0.2.100"
|
||||||
|
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||||
|
|
||||||
|
mockClient := &mockUpstreamResolverPerServer{
|
||||||
|
responses: map[string]mockUpstreamResponse{
|
||||||
|
// Force the broken upstream to only unblock via timeout /
|
||||||
|
// cancellation so the assertion below can't pass if races
|
||||||
|
// were run serially.
|
||||||
|
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
|
||||||
|
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
},
|
||||||
|
rtt: time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := &upstreamResolverBase{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: mockClient,
|
||||||
|
upstreamTimeout: 250 * time.Millisecond,
|
||||||
|
}
|
||||||
|
resolver.addRace([]netip.AddrPort{broken})
|
||||||
|
resolver.addRace([]netip.AddrPort{working})
|
||||||
|
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||||
|
start := time.Now()
|
||||||
|
resolver.ServeDNS(responseWriter, inputMSG)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.NotNil(t, responseMSG, "should write a response")
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
|
||||||
|
require.NotEmpty(t, responseMSG.Answer)
|
||||||
|
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||||
|
// Working group answers in a single RTT; the broken group's
|
||||||
|
// timeout (100ms) must not block the response.
|
||||||
|
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
|
||||||
|
// resolver returns SERVFAIL rather than leaking a partial response.
|
||||||
|
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
|
||||||
|
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||||
|
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||||
|
|
||||||
|
mockClient := &mockUpstreamResolverPerServer{
|
||||||
|
responses: map[string]mockUpstreamResponse{
|
||||||
|
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
},
|
||||||
|
rtt: time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := &upstreamResolverBase{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: mockClient,
|
||||||
|
upstreamTimeout: UpstreamTimeout,
|
||||||
|
}
|
||||||
|
resolver.addRace([]netip.AddrPort{a})
|
||||||
|
resolver.addRace([]netip.AddrPort{b})
|
||||||
|
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||||
|
require.NotNil(t, responseMSG)
|
||||||
|
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpstreamResolver_HealthTracking verifies that query-path results are
|
||||||
|
// recorded into per-upstream health, which is what projects back to
|
||||||
|
// NSGroupState for status reporting.
|
||||||
|
func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||||
|
ok := netip.MustParseAddrPort("192.0.2.10:53")
|
||||||
|
bad := netip.MustParseAddrPort("192.0.2.11:53")
|
||||||
|
|
||||||
|
mockClient := &mockUpstreamResolverPerServer{
|
||||||
|
responses: map[string]mockUpstreamResponse{
|
||||||
|
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
|
||||||
|
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
},
|
||||||
|
rtt: time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := &upstreamResolverBase{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: mockClient,
|
||||||
|
upstreamTimeout: UpstreamTimeout,
|
||||||
|
}
|
||||||
|
resolver.addRace([]netip.AddrPort{ok, bad})
|
||||||
|
|
||||||
|
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||||
|
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||||
|
|
||||||
|
health := resolver.UpstreamHealth()
|
||||||
|
require.Contains(t, health, ok)
|
||||||
|
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
|
||||||
|
assert.Empty(t, health[ok].LastErr)
|
||||||
|
|
||||||
|
// bad upstream was never tried because ok answered first; its health
|
||||||
|
// should remain unset.
|
||||||
|
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||||
|
}
|
||||||
|
|
||||||
func TestFormatFailures(t *testing.T) {
|
func TestFormatFailures(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
|||||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||||
// capped in the outgoing request so the upstream doesn't send a
|
// capped in the outgoing request so the upstream doesn't send a
|
||||||
// response larger than our read buffer.
|
// response larger than our read buffer.
|
||||||
var receivedUDPSize uint16
|
var receivedUDPSize atomic.Uint32
|
||||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if opt := r.IsEdns0(); opt != nil {
|
if opt := r.IsEdns0(); opt != nil {
|
||||||
receivedUDPSize = opt.UDPSize()
|
receivedUDPSize.Store(uint32(opt.UDPSize()))
|
||||||
}
|
}
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
|||||||
require.NotNil(t, rm)
|
require.NotNil(t, rm)
|
||||||
|
|
||||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
|
||||||
"upstream should see capped EDNS0, not the client's 4096")
|
"upstream should see capped EDNS0, not the client's 4096")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
|||||||
resolver := &upstreamResolverBase{
|
resolver := &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
upstreamClient: tracking,
|
upstreamClient: tracking,
|
||||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
upstreamServers: []upstreamRace{{upstream1, upstream2}},
|
||||||
upstreamTimeout: UpstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,9 +61,11 @@ import (
|
|||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
types "github.com/netbirdio/netbird/shared/management/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||||
@@ -202,6 +204,13 @@ type Engine struct {
|
|||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
|
|
||||||
|
// latestComponents is the most-recent NetworkMapComponents decoded from
|
||||||
|
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
|
||||||
|
// NetworkMap that Calculate() produced from it so future incremental
|
||||||
|
// updates have a base to apply changes against. nil for legacy-format
|
||||||
|
// peers. Guarded by syncMsgMux.
|
||||||
|
latestComponents *types.NetworkMapComponents
|
||||||
|
|
||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServer sshServer
|
sshServer sshServer
|
||||||
@@ -512,16 +521,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
|
||||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
|
||||||
for _, r := range routes {
|
|
||||||
if r.Network.Contains(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
})
|
|
||||||
|
|
||||||
if err = e.wgInterfaceCreate(); err != nil {
|
if err = e.wgInterfaceCreate(); err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||||
@@ -874,8 +874,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return e.ctx.Err()
|
return e.ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
// Envelope sync responses carry PeerConfig at the top level; legacy
|
||||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
|
||||||
|
if pc := update.GetPeerConfig(); pc != nil {
|
||||||
|
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
|
||||||
|
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
|
||||||
|
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
@@ -916,11 +920,45 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nm := update.GetNetworkMap()
|
var (
|
||||||
|
nm *mgmProto.NetworkMap
|
||||||
|
components *types.NetworkMapComponents
|
||||||
|
)
|
||||||
|
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
|
||||||
|
// Components-format peer: decode the envelope back to typed
|
||||||
|
// components, run Calculate() locally, and convert to the wire
|
||||||
|
// NetworkMap shape the rest of the engine consumes. Components are
|
||||||
|
// retained so future incremental updates can apply deltas instead
|
||||||
|
// of doing a full reconstruction.
|
||||||
|
localKey := e.config.WgPrivateKey.PublicKey().String()
|
||||||
|
dnsName := ""
|
||||||
|
if pc := update.GetPeerConfig(); pc != nil {
|
||||||
|
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
|
||||||
|
// shared domain by stripping the peer's own label prefix. Falls
|
||||||
|
// back to empty if the FQDN doesn't have the expected shape.
|
||||||
|
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
|
||||||
|
}
|
||||||
|
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decode network map envelope: %w", err)
|
||||||
|
}
|
||||||
|
nm = result.NetworkMap
|
||||||
|
components = result.Components
|
||||||
|
} else {
|
||||||
|
nm = update.GetNetworkMap()
|
||||||
|
}
|
||||||
if nm == nil {
|
if nm == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only retain the components view when the server sent the envelope
|
||||||
|
// path. A legacy proto.NetworkMap means components == nil; writing it
|
||||||
|
// here would clobber a previously-cached snapshot, breaking the
|
||||||
|
// incremental-delta base on a future envelope sync.
|
||||||
|
if components != nil {
|
||||||
|
e.latestComponents = components
|
||||||
|
}
|
||||||
|
|
||||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||||
// Read the storage-enabled flag under the syncRespMux too.
|
// Read the storage-enabled flag under the syncRespMux too.
|
||||||
e.syncRespMux.RLock()
|
e.syncRespMux.RLock()
|
||||||
@@ -946,6 +984,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
|
||||||
|
// receiving peer's FQDN — the same value the management server fills as
|
||||||
|
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
|
||||||
|
// "netbird.cloud". An empty string is returned for unrecognized formats.
|
||||||
|
func extractDNSDomainFromFQDN(fqdn string) string {
|
||||||
|
for i := 0; i < len(fqdn); i++ {
|
||||||
|
if fqdn[i] == '.' && i+1 < len(fqdn) {
|
||||||
|
return fqdn[i+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||||
if update != nil {
|
if update != nil {
|
||||||
// when we receive token we expect valid address list too
|
// when we receive token we expect valid address list too
|
||||||
@@ -1386,9 +1437,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
|
|
||||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
|
||||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
|
||||||
go e.dnsServer.ProbeAvailability()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1932,7 +1980,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
|||||||
return dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
case "ios":
|
case "ios":
|
||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||||
return dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -1979,6 +2027,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
|||||||
return e.clientMetrics
|
return e.clientMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||||
|
// See Engine.SetPerformance. Nil fields are ignored.
|
||||||
|
type Performance struct {
|
||||||
|
PreallocatedBuffersPerPool *uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPerformance applies the given tuning to this engine's live Device.
|
||||||
|
func (e *Engine) SetPerformance(t Performance) error {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return fmt.Errorf("wg interface not initialized")
|
||||||
|
}
|
||||||
|
dev := e.wgInterface.GetWGDevice()
|
||||||
|
if dev == nil {
|
||||||
|
return fmt.Errorf("wg device not initialized")
|
||||||
|
}
|
||||||
|
if t.PreallocatedBuffersPerPool != nil {
|
||||||
|
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
@@ -66,8 +66,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
|||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
// handle route changes
|
// handle route changes
|
||||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||||
route, err := parseRouteMessage(buf[:n])
|
route, flags, err := parseRouteMessage(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
|||||||
}
|
}
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case unix.RTM_ADD:
|
case unix.RTM_ADD:
|
||||||
|
if systemops.IgnoreAddedDefaultRoute(flags) {
|
||||||
|
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
|
||||||
|
continue
|
||||||
|
}
|
||||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||||
return nil
|
return nil
|
||||||
case unix.RTM_DELETE:
|
case unix.RTM_DELETE:
|
||||||
@@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
return nil, 0, fmt.Errorf("parse RIB: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(msgs) != 1 {
|
if len(msgs) != 1 {
|
||||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, ok := msgs[0].(*route.RouteMessage)
|
msg, ok := msgs[0].(*route.RouteMessage)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
return systemops.MsgToRoute(msg)
|
r, err := systemops.MsgToRoute(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return r, msg.Flags, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to deterministic key if no NetBird PSK is configured
|
// Fallback to deterministic key if no NetBird PSK is configured
|
||||||
determKey, err := conn.rosenpassDetermKey()
|
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
return determKey
|
return determKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: move this logic into Rosenpass package
|
|
||||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
|
||||||
lk := []byte(conn.config.LocalKey)
|
|
||||||
rk := []byte(conn.config.Key) // remote key
|
|
||||||
var keyInput []byte
|
|
||||||
if string(lk) > string(rk) {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(lk[:16], rk[:16]...)
|
|
||||||
} else {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(rk[:16], lk[:16]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := wgtypes.NewKey(keyInput)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
|||||||
return s.eventsChan
|
return s.eventsChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status holds a state of peers, signal, management connections and relays
|
// Status holds a state of peers, signal, management connections and relays.
|
||||||
|
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
|
||||||
|
// every private-service request) don't contend against each other.
|
||||||
|
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||||
type Status struct {
|
type Status struct {
|
||||||
mux sync.Mutex
|
mux sync.RWMutex
|
||||||
peers map[string]State
|
peers map[string]State
|
||||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||||
signalState bool
|
signalState bool
|
||||||
@@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
|||||||
|
|
||||||
// GetPeer adds peer to Daemon status map
|
// GetPeer adds peer to Daemon status map
|
||||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
state, ok := d.peers[peerPubKey]
|
state, ok := d.peers[peerPubKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) PeerByIP(ip string) (string, bool) {
|
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
for _, state := range d.peers {
|
for _, state := range d.peers {
|
||||||
if state.IP == ip {
|
if state.IP == ip {
|
||||||
@@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||||
|
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||||
|
// address so dual-stack peers are reachable on either family. Returns the
|
||||||
|
// zero State and false when no peer matches or the input is empty.
|
||||||
|
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||||
|
if ip == "" {
|
||||||
|
return State{}, false
|
||||||
|
}
|
||||||
|
d.mux.RLock()
|
||||||
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
|
for _, state := range d.peers {
|
||||||
|
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||||
|
return state, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return State{}, false
|
||||||
|
}
|
||||||
|
|
||||||
// RemovePeer removes peer from Daemon status map
|
// RemovePeer removes peer from Daemon status map
|
||||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
|
|||||||
|
|
||||||
// GetLocalPeerState returns the local peer state
|
// GetLocalPeerState returns the local peer state
|
||||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return d.localPeer.Clone()
|
return d.localPeer.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetRosenpassState() RosenpassState {
|
func (d *Status) GetRosenpassState() RosenpassState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return RosenpassState{
|
return RosenpassState{
|
||||||
d.rosenpassEnabled,
|
d.rosenpassEnabled,
|
||||||
d.rosenpassPermissive,
|
d.rosenpassPermissive,
|
||||||
@@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetLazyConnection() bool {
|
func (d *Status) GetLazyConnection() bool {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return d.lazyConnectionEnabled
|
return d.lazyConnectionEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetManagementState() ManagementState {
|
func (d *Status) GetManagementState() ManagementState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return ManagementState{
|
return ManagementState{
|
||||||
d.mgmAddress,
|
d.mgmAddress,
|
||||||
d.managementState,
|
d.managementState,
|
||||||
@@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
|||||||
|
|
||||||
// IsLoginRequired determines if a peer's login has expired.
|
// IsLoginRequired determines if a peer's login has expired.
|
||||||
func (d *Status) IsLoginRequired() bool {
|
func (d *Status) IsLoginRequired() bool {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
// if peer is connected to the management then login is not expired
|
// if peer is connected to the management then login is not expired
|
||||||
if d.managementState {
|
if d.managementState {
|
||||||
@@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetSignalState() SignalState {
|
func (d *Status) GetSignalState() SignalState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return SignalState{
|
return SignalState{
|
||||||
d.signalAddress,
|
d.signalAddress,
|
||||||
d.signalState,
|
d.signalState,
|
||||||
@@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState {
|
|||||||
|
|
||||||
// GetRelayStates returns the stun/turn/permanent relay states
|
// GetRelayStates returns the stun/turn/permanent relay states
|
||||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
if d.relayMgr == nil {
|
if d.relayMgr == nil {
|
||||||
return d.relayStates
|
return d.relayStates
|
||||||
}
|
}
|
||||||
@@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
if d.ingressGwMgr == nil {
|
if d.ingressGwMgr == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetDNSStates() []NSGroupState {
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
// shallow copy is good enough, as slices fields are currently not updated
|
// shallow copy is good enough, as slices fields are currently not updated
|
||||||
return slices.Clone(d.nsGroupStates)
|
return slices.Clone(d.nsGroupStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return maps.Clone(d.resolvedDomainsStates)
|
return maps.Clone(d.resolvedDomainsStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus {
|
|||||||
LazyConnectionEnabled: d.GetLazyConnection(),
|
LazyConnectionEnabled: d.GetLazyConnection(),
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
fullStatus.LocalPeerState = d.localPeer
|
fullStatus.LocalPeerState = d.localPeer
|
||||||
|
|
||||||
@@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
if d.wgIface == nil {
|
if d.wgIface == nil {
|
||||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
|
|||||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatus_PeerStateByIP(t *testing.T) {
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
req := require.New(t)
|
||||||
|
|
||||||
|
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
|
||||||
|
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
|
||||||
|
|
||||||
|
state, ok := status.PeerStateByIP("100.64.0.10")
|
||||||
|
req.True(ok, "known tunnel IP should resolve to a peer state")
|
||||||
|
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||||
|
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
|
||||||
|
|
||||||
|
_, ok = status.PeerStateByIP("100.64.0.99")
|
||||||
|
req.False(ok, "unknown IP must report ok=false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
req := require.New(t)
|
||||||
|
|
||||||
|
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||||
|
|
||||||
|
state, ok := status.PeerStateByIP("fd00::1")
|
||||||
|
req.True(ok, "IPv6-only match must resolve to the peer state")
|
||||||
|
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||||
|
}
|
||||||
|
|
||||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
fqdn := "peer-a.netbird.local"
|
fqdn := "peer-a.netbird.local"
|
||||||
|
|||||||
@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
|
|||||||
return hex.EncodeToString(hasher.Sum(nil))
|
return hex.EncodeToString(hasher.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
|
||||||
|
// so tests can substitute a mock without spinning up a real UDP server.
|
||||||
|
type rpServer interface {
|
||||||
|
AddPeer(rp.PeerConfig) (rp.PeerID, error)
|
||||||
|
RemovePeer(rp.PeerID) error
|
||||||
|
Run() error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
ifaceName string
|
ifaceName string
|
||||||
spk []byte
|
spk []byte
|
||||||
@@ -36,7 +45,7 @@ type Manager struct {
|
|||||||
preSharedKey *[32]byte
|
preSharedKey *[32]byte
|
||||||
rpPeerIDs map[string]*rp.PeerID
|
rpPeerIDs map[string]*rp.PeerID
|
||||||
rpWgHandler *NetbirdHandler
|
rpWgHandler *NetbirdHandler
|
||||||
server *rp.Server
|
server rpServer
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
port int
|
port int
|
||||||
wgIface PresharedKeySetter
|
wgIface PresharedKeySetter
|
||||||
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
|||||||
|
|
||||||
rpKeyHash := hashRosenpassKey(public)
|
rpKeyHash := hashRosenpassKey(public)
|
||||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
return &Manager{
|
||||||
|
ifaceName: wgIfaceName,
|
||||||
|
rpKeyHash: rpKeyHash,
|
||||||
|
spk: public,
|
||||||
|
ssk: secret,
|
||||||
|
preSharedKey: (*[32]byte)(preSharedKey),
|
||||||
|
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||||
|
// rpWgHandler is created here (instead of only in generateConfig) so it
|
||||||
|
// is never nil between NewManager and Run(). Otherwise an early
|
||||||
|
// OnConnected call (race observed on Android, issue #4341) panics on
|
||||||
|
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
|
||||||
|
// replace it with a fresh handler on each Run() to clear stale peer
|
||||||
|
// state from previous engine sessions.
|
||||||
|
rpWgHandler: NewNetbirdHandler(),
|
||||||
|
lock: sync.Mutex{},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetPubKey() []byte {
|
func (m *Manager) GetPubKey() []byte {
|
||||||
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
|
|||||||
|
|
||||||
// addPeer adds a new peer to the Rosenpass server
|
// addPeer adds a new peer to the Rosenpass server
|
||||||
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
||||||
|
// Defense in depth against issue #4341 (Android crash): if Run() has not
|
||||||
|
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
|
||||||
|
// error instead of panicking on nil-receiver dereference.
|
||||||
|
if m.server == nil {
|
||||||
|
return fmt.Errorf("rosenpass server not initialized")
|
||||||
|
}
|
||||||
|
if m.rpWgHandler == nil {
|
||||||
|
return fmt.Errorf("rosenpass wg handler not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
||||||
if m.preSharedKey != nil {
|
if m.preSharedKey != nil {
|
||||||
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
|
|||||||
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
||||||
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
||||||
}
|
}
|
||||||
|
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
|
||||||
|
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
|
||||||
|
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
|
||||||
|
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
|
||||||
|
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
|
||||||
|
// IPv6 so its address family matches our listening socket.
|
||||||
|
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
|
||||||
|
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
|
||||||
|
pcfg.Endpoint.IP = v4.To16()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
peerID, err := m.server.AddPeer(pcfg)
|
peerID, err := m.server.AddPeer(pcfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.server, err = rp.NewUDPServer(conf)
|
server, err := rp.NewUDPServer(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
m.server = server
|
||||||
|
m.lock.Unlock()
|
||||||
|
|
||||||
log.Infof("starting rosenpass server on port %d", m.port)
|
log.Infof("starting rosenpass server on port %d", m.port)
|
||||||
|
|
||||||
return m.server.Run()
|
return server.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the Rosenpass server
|
// Close closes the Rosenpass server
|
||||||
func (m *Manager) Close() error {
|
func (m *Manager) Close() error {
|
||||||
if m.server != nil {
|
m.lock.Lock()
|
||||||
err := m.server.Close()
|
server := m.server
|
||||||
if err != nil {
|
m.server = nil
|
||||||
log.Errorf("failed closing local rosenpass server")
|
m.lock.Unlock()
|
||||||
}
|
if server == nil {
|
||||||
m.server = nil
|
return nil
|
||||||
|
}
|
||||||
|
if err := server.Close(); err != nil {
|
||||||
|
log.Errorf("failed closing local rosenpass server: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,412 @@
|
|||||||
package rosenpass
|
package rosenpass
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
rp "cunicu.li/go-rosenpass"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// --- test doubles -----------------------------------------------------------
|
||||||
|
|
||||||
|
type addPeerCall struct {
|
||||||
|
cfg rp.PeerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type removePeerCall struct {
|
||||||
|
id rp.PeerID
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockServer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
addCalls []addPeerCall
|
||||||
|
removed []removePeerCall
|
||||||
|
nextID rp.PeerID
|
||||||
|
addErr error
|
||||||
|
removeErr error
|
||||||
|
closed bool
|
||||||
|
ran bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
|
||||||
|
if m.addErr != nil {
|
||||||
|
return rp.PeerID{}, m.addErr
|
||||||
|
}
|
||||||
|
// Increment a byte in nextID so distinct peers get distinct IDs.
|
||||||
|
m.nextID[0]++
|
||||||
|
return m.nextID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.removed = append(m.removed, removePeerCall{id: id})
|
||||||
|
return m.removeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||||
|
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||||
|
|
||||||
|
type setPSKCall struct {
|
||||||
|
peerKey string
|
||||||
|
psk wgtypes.Key
|
||||||
|
updateOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockIface struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls []setPSKCall
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestManager builds a Manager with deterministic spk so tie-break
|
||||||
|
// against a peer pubkey is controllable from tests. The provided spk byte
|
||||||
|
// becomes the first byte; remaining bytes are zero.
|
||||||
|
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
|
||||||
|
spk := make([]byte, 32)
|
||||||
|
spk[0] = spkFirstByte
|
||||||
|
return &Manager{
|
||||||
|
ifaceName: "wt0",
|
||||||
|
spk: spk,
|
||||||
|
ssk: make([]byte, 32),
|
||||||
|
rpKeyHash: "test-hash",
|
||||||
|
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||||
|
rpWgHandler: NewNetbirdHandler(),
|
||||||
|
server: mock,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
|
||||||
|
func validWGKey(t *testing.T, lastByte byte) string {
|
||||||
|
t.Helper()
|
||||||
|
var k wgtypes.Key
|
||||||
|
k[31] = lastByte
|
||||||
|
return k.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- pure helpers ----------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHashRosenpassKey_Deterministic(t *testing.T) {
|
||||||
|
key := []byte("hello-rosenpass")
|
||||||
|
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
|
||||||
|
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
|
||||||
|
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
|
||||||
|
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
|
||||||
|
// can only set, not delete, so do it manually with restore via t.Cleanup.
|
||||||
|
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
|
||||||
|
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if hadPrev {
|
||||||
|
_ = os.Setenv(defaultLogLevelVar, prev)
|
||||||
|
} else {
|
||||||
|
_ = os.Unsetenv(defaultLogLevelVar)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
require.Equal(t, defaultLog.String(), getLogLevel().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLogLevel_Cases(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"debug": "DEBUG",
|
||||||
|
"info": "INFO",
|
||||||
|
"warn": "WARN",
|
||||||
|
"error": "ERROR",
|
||||||
|
"unknown": "INFO", // default fallback
|
||||||
|
}
|
||||||
|
for input, wantStr := range cases {
|
||||||
|
input, wantStr := input, wantStr
|
||||||
|
t.Run(input, func(t *testing.T) {
|
||||||
|
t.Setenv(defaultLogLevelVar, input)
|
||||||
|
require.Equal(t, wantStr, getLogLevel().String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
||||||
port, err := findRandomAvailableUDPPort()
|
port, err := findRandomAvailableUDPPort()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Greater(t, port, 0)
|
require.Greater(t, port, 0)
|
||||||
require.LessOrEqual(t, port, 65535)
|
require.LessOrEqual(t, port, 65535)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- addPeer ---------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv) // local spk lexicographically larger
|
||||||
|
|
||||||
|
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
|
||||||
|
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, srv.addCalls, 1)
|
||||||
|
|
||||||
|
ep := srv.addCalls[0].cfg.Endpoint
|
||||||
|
require.NotNil(t, ep, "initiator side must set Endpoint")
|
||||||
|
require.Equal(t, 7000, ep.Port)
|
||||||
|
require.Equal(t, "100.1.1.1", ep.IP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
|
||||||
|
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
|
||||||
|
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ep := srv.addCalls[0].cfg.Endpoint
|
||||||
|
require.NotNil(t, ep)
|
||||||
|
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
|
||||||
|
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0x00, srv) // local spk smaller
|
||||||
|
|
||||||
|
remotePubKey := make([]byte, 32)
|
||||||
|
remotePubKey[0] = 0xFF
|
||||||
|
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
psk := &wgtypes.Key{0x42}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
m.preSharedKey = (*[32]byte)(psk)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_ServerError_Propagates(t *testing.T) {
|
||||||
|
srv := &mockServer{addErr: errors.New("boom")}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regression guard for issue #4341 (Android crash). If Run() has not completed
|
||||||
|
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
|
||||||
|
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
|
||||||
|
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
m.rpWgHandler = nil // simulate Run() not yet completed
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "wg handler not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
|
||||||
|
m := newTestManager(0xFF, nil)
|
||||||
|
m.server = nil // simulate Run() not yet completed
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "server not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
|
||||||
|
// issue #4341 cannot occur in the window between NewManager and Run().
|
||||||
|
func TestNewManager_PreInitializesHandler(t *testing.T) {
|
||||||
|
psk := wgtypes.Key{}
|
||||||
|
m, err := NewManager(&psk, "wt0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_RecordsPeerID(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 5)
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- OnConnected / OnDisconnected ------------------------------------------
|
||||||
|
|
||||||
|
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
|
||||||
|
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
|
||||||
|
require.Empty(t, m.rpPeerIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 1)
|
||||||
|
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
|
||||||
|
require.Len(t, srv.addCalls, 1)
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
m.OnDisconnected(validWGKey(t, 99))
|
||||||
|
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 1)
|
||||||
|
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
|
||||||
|
m.OnDisconnected(wgKey)
|
||||||
|
require.Len(t, srv.removed, 1)
|
||||||
|
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- IsPresharedKeyInitialized ---------------------------------------------
|
||||||
|
|
||||||
|
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 2)
|
||||||
|
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||||
|
require.False(t, m.IsPresharedKeyInitialized(wgKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- NetbirdHandler.outputKey ----------------------------------------------
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x01}
|
||||||
|
wgKey := wgtypes.Key{0xAA}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||||
|
|
||||||
|
psk := rp.Key{0xBB}
|
||||||
|
h.HandshakeCompleted(pid, psk)
|
||||||
|
|
||||||
|
require.Len(t, iface.calls, 1)
|
||||||
|
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
|
||||||
|
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x02}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
|
||||||
|
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
|
||||||
|
|
||||||
|
require.Len(t, iface.calls, 2)
|
||||||
|
require.False(t, iface.calls[0].updateOnly)
|
||||||
|
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
// no SetInterface — iface remains nil
|
||||||
|
pid := rp.PeerID{0x03}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
|
||||||
|
|
||||||
|
// Must not panic.
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
|
||||||
|
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x04}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x01})
|
||||||
|
require.True(t, h.IsPeerInitialized(pid))
|
||||||
|
|
||||||
|
h.RemovePeer(pid)
|
||||||
|
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
pid := rp.PeerID{0x05}
|
||||||
|
wgKey := wgtypes.Key{0xEE}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||||
|
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface) // set after AddPeer
|
||||||
|
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x42})
|
||||||
|
require.Len(t, iface.calls, 1)
|
||||||
|
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||||
|
}
|
||||||
|
|||||||
42
client/internal/rosenpass/seed.go
Normal file
42
client/internal/rosenpass/seed.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package rosenpass
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
|
||||||
|
// of peer public keys. Both peers, given the same key pair, produce the same
|
||||||
|
// output regardless of which side runs the function: the inputs are ordered
|
||||||
|
// lexicographically before concatenation.
|
||||||
|
//
|
||||||
|
// NetBird uses this value as the initial Rosenpass-side preshared key when no
|
||||||
|
// explicit account-level PSK is configured, so both peers converge on the same
|
||||||
|
// PSK before the first post-quantum handshake completes.
|
||||||
|
//
|
||||||
|
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
|
||||||
|
// from public keys and exists only to seed WireGuard until Rosenpass rotates
|
||||||
|
// in a real post-quantum PSK.
|
||||||
|
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
|
||||||
|
lk := []byte(localKey)
|
||||||
|
rk := []byte(remoteKey)
|
||||||
|
if len(lk) < 16 || len(rk) < 16 {
|
||||||
|
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
|
||||||
|
}
|
||||||
|
|
||||||
|
var keyInput []byte
|
||||||
|
if localKey > remoteKey {
|
||||||
|
keyInput = append(keyInput, lk[:16]...)
|
||||||
|
keyInput = append(keyInput, rk[:16]...)
|
||||||
|
} else {
|
||||||
|
keyInput = append(keyInput, rk[:16]...)
|
||||||
|
keyInput = append(keyInput, lk[:16]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := wgtypes.NewKey(keyInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
|
||||||
|
}
|
||||||
|
return &key, nil
|
||||||
|
}
|
||||||
44
client/internal/rosenpass/seed_test.go
Normal file
44
client/internal/rosenpass/seed_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package rosenpass
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
|
||||||
|
// Peer A and peer B must derive the same PSK regardless of which side
|
||||||
|
// computes it: the function orders inputs internally.
|
||||||
|
a := strings.Repeat("a", 32)
|
||||||
|
b := strings.Repeat("b", 32)
|
||||||
|
|
||||||
|
keyAB, err := DeterministicSeedKey(a, b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyBA, err := DeterministicSeedKey(b, a)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
|
||||||
|
a := strings.Repeat("a", 32)
|
||||||
|
b := strings.Repeat("b", 32)
|
||||||
|
c := strings.Repeat("c", 32)
|
||||||
|
|
||||||
|
keyAB, err := DeterministicSeedKey(a, b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyAC, err := DeterministicSeedKey(a, c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||||
|
short := "short" // < 16 bytes
|
||||||
|
long := strings.Repeat("x", 32)
|
||||||
|
|
||||||
|
_, err := DeterministicSeedKey(short, long)
|
||||||
|
require.Error(t, err)
|
||||||
|
_, err = DeterministicSeedKey(long, short)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
@@ -53,6 +53,7 @@ type Manager interface {
|
|||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
GetClientRoutes() route.HAMap
|
GetClientRoutes() route.HAMap
|
||||||
GetSelectedClientRoutes() route.HAMap
|
GetSelectedClientRoutes() route.HAMap
|
||||||
|
GetActiveClientRoutes() route.HAMap
|
||||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
|||||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetActiveClientRoutes returns the subset of selected client routes
|
||||||
|
// that are currently reachable: the route's peer is Connected and is
|
||||||
|
// the one actively carrying the route (not just an HA sibling).
|
||||||
|
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
|
||||||
|
m.mux.Lock()
|
||||||
|
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||||
|
recorder := m.statusRecorder
|
||||||
|
m.mux.Unlock()
|
||||||
|
|
||||||
|
if recorder == nil {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make(route.HAMap, len(selected))
|
||||||
|
for id, routes := range selected {
|
||||||
|
for _, r := range routes {
|
||||||
|
st, err := recorder.GetPeer(r.Peer)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if st.ConnStatus != peer.StatusConnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[id] = routes
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
|
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
|
||||||
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
|
if len(routes) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {
|
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type MockManager struct {
|
|||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
GetClientRoutesFunc func() route.HAMap
|
GetClientRoutesFunc func() route.HAMap
|
||||||
GetSelectedClientRoutesFunc func() route.HAMap
|
GetSelectedClientRoutesFunc func() route.HAMap
|
||||||
|
GetActiveClientRoutesFunc func() route.HAMap
|
||||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||||
StopFunc func(manager *statemanager.Manager)
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
|
||||||
|
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
|
||||||
|
if m.GetActiveClientRoutesFunc != nil {
|
||||||
|
return m.GetActiveClientRoutesFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||||
|
// given flags should be ignored by the network monitor.
|
||||||
|
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||||
|
return filterRoutesByFlags(flags)
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
//go:build darwin
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||||
|
// given flags should be ignored by the network monitor. Scoped routes
|
||||||
|
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
|
||||||
|
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
|
||||||
|
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
|
||||||
|
// must not trigger an engine restart.
|
||||||
|
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||||
|
if filterRoutesByFlags(flags) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_IFSCOPE != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
@@ -12,10 +13,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
exitNodeCIDR = "0.0.0.0/0"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RouteSelector struct {
|
type RouteSelector struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
deselectedRoutes map[route.NetID]struct{}
|
deselectedRoutes map[route.NetID]struct{}
|
||||||
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
|||||||
rs.mu.RLock()
|
rs.mu.RLock()
|
||||||
defer rs.mu.RUnlock()
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
if rs.deselectAll {
|
return rs.isSelectedLocked(routeID)
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
_, deselected := rs.deselectedRoutes[routeID]
|
|
||||||
isSelected := !deselected
|
|
||||||
return isSelected
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterSelected removes unselected routes from the provided map.
|
// FilterSelected removes unselected routes from the provided map.
|
||||||
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
|||||||
|
|
||||||
filtered := route.HAMap{}
|
filtered := route.HAMap{}
|
||||||
for id, rt := range routes {
|
for id, rt := range routes {
|
||||||
netID := id.NetID()
|
if !rs.isDeselectedLocked(id.NetID()) {
|
||||||
_, deselected := rs.deselectedRoutes[netID]
|
|
||||||
if !deselected {
|
|
||||||
filtered[id] = rt
|
filtered[id] = rt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
|
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
|
||||||
|
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
|
||||||
|
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
|
||||||
|
// transparently apply to the synthesized v6 entry.
|
||||||
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
|
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
|
||||||
rs.mu.RLock()
|
rs.mu.RLock()
|
||||||
defer rs.mu.RUnlock()
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
_, selected := rs.selectedRoutes[routeID]
|
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
|
||||||
_, deselected := rs.deselectedRoutes[routeID]
|
|
||||||
return selected || deselected
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
|
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
|
||||||
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
|||||||
filtered := make(route.HAMap, len(routes))
|
filtered := make(route.HAMap, len(routes))
|
||||||
for id, rt := range routes {
|
for id, rt := range routes {
|
||||||
netID := id.NetID()
|
netID := id.NetID()
|
||||||
if rs.isDeselected(netID) {
|
if rs.isDeselectedLocked(netID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
|||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
|
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
|
||||||
|
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
|
||||||
|
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
|
||||||
|
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
|
||||||
|
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
|
||||||
|
name := string(id)
|
||||||
|
if !strings.HasSuffix(name, route.V6ExitSuffix) {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if _, ok := rs.selectedRoutes[id]; ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if _, ok := rs.deselectedRoutes[id]; ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
|
||||||
|
if rs.deselectAll {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, deselected := rs.deselectedRoutes[routeID]
|
||||||
|
return !deselected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
|
||||||
|
if rs.deselectAll {
|
||||||
|
return true
|
||||||
|
}
|
||||||
_, deselected := rs.deselectedRoutes[netID]
|
_, deselected := rs.deselectedRoutes[netID]
|
||||||
return deselected || rs.deselectAll
|
return deselected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
|
||||||
|
_, selected := rs.selectedRoutes[routeID]
|
||||||
|
_, deselected := rs.deselectedRoutes[routeID]
|
||||||
|
return selected || deselected
|
||||||
}
|
}
|
||||||
|
|
||||||
func isExitNode(rt []*route.Route) bool {
|
func isExitNode(rt []*route.Route) bool {
|
||||||
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
|
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RouteSelector) applyExitNodeFilter(
|
func (rs *RouteSelector) applyExitNodeFilter(
|
||||||
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
|
|||||||
rt []*route.Route,
|
rt []*route.Route,
|
||||||
out route.HAMap,
|
out route.HAMap,
|
||||||
) {
|
) {
|
||||||
|
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
|
||||||
if rs.hasUserSelections() {
|
// drops the synthesized v6 entry that lacks its own explicit state.
|
||||||
// user made explicit selects/deselects
|
effective := rs.effectiveNetID(netID)
|
||||||
if rs.IsSelected(netID) {
|
if rs.hasUserSelectionForRouteLocked(effective) {
|
||||||
|
if rs.isSelectedLocked(effective) {
|
||||||
out[id] = rt
|
out[id] = rt
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
|
// no explicit selection for this route: defer to management's SkipAutoApply flag
|
||||||
sel := collectSelected(rt)
|
sel := collectSelected(rt)
|
||||||
if len(sel) > 0 {
|
if len(sel) > 0 {
|
||||||
out[id] = sel
|
out[id] = sel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RouteSelector) hasUserSelections() bool {
|
|
||||||
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func collectSelected(rt []*route.Route) []*route.Route {
|
func collectSelected(rt []*route.Route) []*route.Route {
|
||||||
var sel []*route.Route
|
var sel []*route.Route
|
||||||
for _, r := range rt {
|
for _, r := range rt {
|
||||||
|
|||||||
@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
|
|||||||
assert.Len(t, filtered, 0) // No routes should be selected
|
assert.Len(t, filtered, 0) // No routes should be selected
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
|
||||||
|
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
|
||||||
|
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
|
||||||
|
// v4 base, so legacy persisted selections that predate v6 pairing transparently
|
||||||
|
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
|
||||||
|
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
|
||||||
|
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||||
|
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
|
||||||
|
|
||||||
|
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||||
|
|
||||||
|
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
|
||||||
|
|
||||||
|
// unrelated v6 with no v4 base touched is unaffected
|
||||||
|
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||||
|
|
||||||
|
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
|
||||||
|
// via the mirror; the mirror only applies in exit-node code paths.
|
||||||
|
assert.False(t, rs.IsSelected("corp"))
|
||||||
|
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||||
|
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
|
||||||
|
|
||||||
|
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||||
|
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||||
|
routes := route.HAMap{
|
||||||
|
"exit1|0.0.0.0/0": {v4Route},
|
||||||
|
"exit1-v6|::/0": {v6Route},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := rs.FilterSelectedExitNodes(routes)
|
||||||
|
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
|
||||||
|
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||||
|
|
||||||
|
// A route literally named "exit1-something" must not pair-resolve.
|
||||||
|
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||||
|
|
||||||
|
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||||
|
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||||
|
routes := route.HAMap{
|
||||||
|
"exit1|0.0.0.0/0": {v4Route},
|
||||||
|
"exit1-v6|::/0": {v6Route},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := rs.FilterSelectedExitNodes(routes)
|
||||||
|
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||||
|
|
||||||
|
// A non-default-route entry named "corp-v6" is not an exit node and
|
||||||
|
// must not be skipped because its v4 base "corp" is deselected.
|
||||||
|
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
|
||||||
|
routes := route.HAMap{
|
||||||
|
"corp-v6|10.0.0.0/8": {corpV6},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := rs.FilterSelectedExitNodes(routes)
|
||||||
|
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
|
||||||
|
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
|
||||||
|
// SkipAutoApply flag governs each untouched route independently, even when
|
||||||
|
// the user has explicit selections on other routes.
|
||||||
|
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
|
||||||
|
autoApplied := &route.Route{
|
||||||
|
NetID: "Auto",
|
||||||
|
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
SkipAutoApply: false,
|
||||||
|
}
|
||||||
|
skipApply := &route.Route{
|
||||||
|
NetID: "Skip",
|
||||||
|
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
SkipAutoApply: true,
|
||||||
|
}
|
||||||
|
routes := route.HAMap{
|
||||||
|
"Auto|0.0.0.0/0": {autoApplied},
|
||||||
|
"Skip|0.0.0.0/0": {skipApply},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
// User makes an unrelated explicit selection elsewhere.
|
||||||
|
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
|
||||||
|
|
||||||
|
filtered := rs.FilterSelectedExitNodes(routes)
|
||||||
|
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
|
||||||
|
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
|
||||||
|
// as exit nodes by the selector's filter path.
|
||||||
|
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
|
||||||
|
v6Exit := &route.Route{
|
||||||
|
NetID: "V6Only",
|
||||||
|
Network: netip.MustParsePrefix("::/0"),
|
||||||
|
SkipAutoApply: true,
|
||||||
|
}
|
||||||
|
routes := route.HAMap{
|
||||||
|
"V6Only|::/0": {v6Exit},
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
filtered := rs.FilterSelectedExitNodes(routes)
|
||||||
|
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
|
||||||
|
}
|
||||||
|
|
||||||
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
||||||
initialRoutes := []route.NetID{"route1", "route2", "route3"}
|
initialRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
|
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
|
||||||
|
|||||||
@@ -188,7 +188,9 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
|
|||||||
}
|
}
|
||||||
|
|
||||||
doneChan := make(chan struct{})
|
doneChan := make(chan struct{})
|
||||||
timeout := time.NewTimer(500 * time.Millisecond)
|
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
|
||||||
|
// enough for teardown to finish while staying under that deadline.
|
||||||
|
timeout := time.NewTimer(20 * time.Second)
|
||||||
defer timeout.Stop()
|
defer timeout.Stop()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
cancel := m.cancel
|
||||||
|
done := m.done
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
if m.cancel == nil {
|
if cancel == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m.cancel()
|
cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-m.done:
|
case <-done:
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
|||||||
cfg.WgIface = interfaceName
|
cfg.WgIface = interfaceName
|
||||||
|
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
hostDNS := []netip.AddrPort{
|
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||||
netip.MustParseAddrPort("9.9.9.9:53"),
|
|
||||||
netip.MustParseAddrPort("149.112.112.112:53"),
|
|
||||||
}
|
|
||||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -3,15 +3,14 @@
|
|||||||
package system
|
package system
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/zcalusic/sysinfo"
|
"github.com/zcalusic/sysinfo"
|
||||||
|
|
||||||
@@ -29,19 +28,11 @@ func UpdateStaticInfoAsync() {
|
|||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context) *Info {
|
||||||
info := _getInfo()
|
kernelName, kernelVersion, kernelPlatform := kernelInfo()
|
||||||
for strings.Contains(info, "broken pipe") {
|
|
||||||
info = _getInfo()
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
osStr := strings.ReplaceAll(info, "\n", "")
|
|
||||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
|
||||||
osInfo := strings.Split(osStr, " ")
|
|
||||||
|
|
||||||
osName, osVersion := readOsReleaseFile()
|
osName, osVersion := readOsReleaseFile()
|
||||||
if osName == "" {
|
if osName == "" {
|
||||||
osName = osInfo[3]
|
osName = kernelName
|
||||||
}
|
}
|
||||||
|
|
||||||
systemHostname, _ := os.Hostname()
|
systemHostname, _ := os.Hostname()
|
||||||
@@ -58,8 +49,8 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
}
|
}
|
||||||
|
|
||||||
gio := &Info{
|
gio := &Info{
|
||||||
Kernel: osInfo[0],
|
Kernel: kernelName,
|
||||||
Platform: osInfo[2],
|
Platform: kernelPlatform,
|
||||||
OS: osName,
|
OS: osName,
|
||||||
OSVersion: osVersion,
|
OSVersion: osVersion,
|
||||||
Hostname: extractDeviceName(ctx, systemHostname),
|
Hostname: extractDeviceName(ctx, systemHostname),
|
||||||
@@ -67,7 +58,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
CPUs: runtime.NumCPU(),
|
CPUs: runtime.NumCPU(),
|
||||||
NetbirdVersion: version.NetbirdVersion(),
|
NetbirdVersion: version.NetbirdVersion(),
|
||||||
UIVersion: extractUserAgent(ctx),
|
UIVersion: extractUserAgent(ctx),
|
||||||
KernelVersion: osInfo[1],
|
KernelVersion: kernelVersion,
|
||||||
NetworkAddresses: addrs,
|
NetworkAddresses: addrs,
|
||||||
SystemSerialNumber: si.SystemSerialNumber,
|
SystemSerialNumber: si.SystemSerialNumber,
|
||||||
SystemProductName: si.SystemProductName,
|
SystemProductName: si.SystemProductName,
|
||||||
@@ -78,18 +69,12 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
func _getInfo() string {
|
func kernelInfo() (string, string, string) {
|
||||||
cmd := exec.Command("uname", "-srio")
|
var uts unix.Utsname
|
||||||
cmd.Stdin = strings.NewReader("some")
|
if err := unix.Uname(&uts); err != nil {
|
||||||
var out bytes.Buffer
|
return "", "", ""
|
||||||
var stderr bytes.Buffer
|
|
||||||
cmd.Stdout = &out
|
|
||||||
cmd.Stderr = &stderr
|
|
||||||
err := cmd.Run()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("getInfo: %s", err)
|
|
||||||
}
|
}
|
||||||
return out.String()
|
return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func sysInfo() (string, string, string) {
|
func sysInfo() (string, string, string) {
|
||||||
|
|||||||
@@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isDefaultRoute(routeRange string) bool {
|
func isDefaultRoute(routeRange string) bool {
|
||||||
return routeRange == "0.0.0.0/0" || routeRange == "::/0"
|
// routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0",
|
||||||
|
// "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry.
|
||||||
|
for _, part := range strings.Split(routeRange, ",") {
|
||||||
|
switch strings.TrimSpace(part) {
|
||||||
|
case "0.0.0.0/0", "::/0":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {
|
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
certValidationTimeout = 60 * time.Second
|
certValidationTimeout = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||||
@@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
|||||||
|
|
||||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||||
|
|
||||||
resultChan := make(chan bool)
|
resultChan := make(chan bool, 1)
|
||||||
errorChan := make(chan error)
|
errorChan := make(chan error, 1)
|
||||||
|
|
||||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
// Release from inside the callbacks so a post-timeout promise resolution
|
||||||
result := args[0].Bool()
|
// does not invoke an already-released func.
|
||||||
resultChan <- result
|
var thenFn, catchFn js.Func
|
||||||
|
var releaseOnce sync.Once
|
||||||
|
release := func() {
|
||||||
|
releaseOnce.Do(func() {
|
||||||
|
thenFn.Release()
|
||||||
|
catchFn.Release()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
defer release()
|
||||||
|
resultChan <- args[0].Bool()
|
||||||
return nil
|
return nil
|
||||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
})
|
||||||
|
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
defer release()
|
||||||
errorChan <- fmt.Errorf("certificate validation failed")
|
errorChan <- fmt.Errorf("certificate validation failed")
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
promise.Call("then", thenFn).Call("catch", catchFn)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case result := <-resultChan:
|
case result := <-resultChan:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -57,6 +58,8 @@ type RDCleanPathProxy struct {
|
|||||||
}
|
}
|
||||||
activeConnections map[string]*proxyConnection
|
activeConnections map[string]*proxyConnection
|
||||||
destinations map[string]string
|
destinations map[string]string
|
||||||
|
pendingHandlers map[string]js.Func
|
||||||
|
nextID atomic.Uint64
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,8 +69,15 @@ type proxyConnection struct {
|
|||||||
rdpConn net.Conn
|
rdpConn net.Conn
|
||||||
tlsConn *tls.Conn
|
tlsConn *tls.Conn
|
||||||
wsHandlers js.Value
|
wsHandlers js.Value
|
||||||
ctx context.Context
|
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||||
cancel context.CancelFunc
|
// global handle map and MUST be released, otherwise every connection
|
||||||
|
// leaks the Go memory the closure captures.
|
||||||
|
wsHandlerFn js.Func
|
||||||
|
onMessageFn js.Func
|
||||||
|
onCloseFn js.Func
|
||||||
|
cleanupOnce sync.Once
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||||
@@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProxy creates a new proxy endpoint for the given destination
|
// CreateProxy creates a new proxy endpoint for the given destination.
|
||||||
|
// The registered handler fn and its destinations/pendingHandlers entries are
|
||||||
|
// only released once a connection is established and cleanupConnection runs.
|
||||||
|
// If a caller invokes CreateProxy but never connects to the returned URL,
|
||||||
|
// those entries stay pinned for the lifetime of the page.
|
||||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||||
destination := net.JoinHostPort(hostname, port)
|
destination := net.JoinHostPort(hostname, port)
|
||||||
|
|
||||||
@@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
resolve := args[0]
|
resolve := args[0]
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
|
||||||
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
if p.destinations == nil {
|
if p.destinations == nil {
|
||||||
@@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||||
|
|
||||||
// Register the WebSocket handler for this specific proxy
|
// Register the WebSocket handler for this specific proxy
|
||||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return js.ValueOf("error: requires WebSocket argument")
|
return js.ValueOf("error: requires WebSocket argument")
|
||||||
}
|
}
|
||||||
@@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
ws := args[0]
|
ws := args[0]
|
||||||
p.HandleWebSocketConnection(ws, proxyID)
|
p.HandleWebSocketConnection(ws, proxyID)
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
p.mu.Lock()
|
||||||
|
if p.pendingHandlers == nil {
|
||||||
|
p.pendingHandlers = make(map[string]js.Func)
|
||||||
|
}
|
||||||
|
p.pendingHandlers[proxyID] = handlerFn
|
||||||
|
p.mu.Unlock()
|
||||||
|
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
|
||||||
|
|
||||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||||
resolve.Invoke(proxyURL)
|
resolve.Invoke(proxyURL)
|
||||||
@@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
|||||||
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
p.activeConnections[proxyID] = conn
|
p.activeConnections[proxyID] = conn
|
||||||
|
if fn, ok := p.pendingHandlers[proxyID]; ok {
|
||||||
|
conn.wsHandlerFn = fn
|
||||||
|
delete(p.pendingHandlers, proxyID)
|
||||||
|
}
|
||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
|
|
||||||
p.setupWebSocketHandlers(ws, conn)
|
p.setupWebSocketHandlers(ws, conn)
|
||||||
@@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
|
|||||||
data := args[0]
|
data := args[0]
|
||||||
go p.handleWebSocketMessage(conn, data)
|
go p.handleWebSocketMessage(conn, data)
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
ws.Set("onGoMessage", conn.onMessageFn)
|
||||||
|
|
||||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
log.Debug("WebSocket closed by JavaScript")
|
log.Debug("WebSocket closed by JavaScript")
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
ws.Set("onGoClose", conn.onCloseFn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||||
@@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||||
log.Debugf("Cleaning up connection %s", conn.id)
|
conn.cleanupOnce.Do(func() {
|
||||||
conn.cancel()
|
log.Debugf("Cleaning up connection %s", conn.id)
|
||||||
if conn.tlsConn != nil {
|
conn.cancel()
|
||||||
log.Debug("Closing TLS connection")
|
if conn.tlsConn != nil {
|
||||||
if err := conn.tlsConn.Close(); err != nil {
|
log.Debug("Closing TLS connection")
|
||||||
log.Debugf("Error closing TLS connection: %v", err)
|
if err := conn.tlsConn.Close(); err != nil {
|
||||||
|
log.Debugf("Error closing TLS connection: %v", err)
|
||||||
|
}
|
||||||
|
conn.tlsConn = nil
|
||||||
}
|
}
|
||||||
conn.tlsConn = nil
|
if conn.rdpConn != nil {
|
||||||
}
|
log.Debug("Closing TCP connection")
|
||||||
if conn.rdpConn != nil {
|
if err := conn.rdpConn.Close(); err != nil {
|
||||||
log.Debug("Closing TCP connection")
|
log.Debugf("Error closing TCP connection: %v", err)
|
||||||
if err := conn.rdpConn.Close(); err != nil {
|
}
|
||||||
log.Debugf("Error closing TCP connection: %v", err)
|
conn.rdpConn = nil
|
||||||
}
|
}
|
||||||
conn.rdpConn = nil
|
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
|
||||||
}
|
|
||||||
p.mu.Lock()
|
// Detach before releasing so late JS calls surface as TypeError instead
|
||||||
delete(p.activeConnections, conn.id)
|
// of silent "call to released function".
|
||||||
p.mu.Unlock()
|
if conn.wsHandlers.Truthy() {
|
||||||
|
conn.wsHandlers.Set("onGoMessage", js.Undefined())
|
||||||
|
conn.wsHandlers.Set("onGoClose", js.Undefined())
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsHandlerFn may be zero-value if the pending handler lookup missed.
|
||||||
|
if conn.wsHandlerFn.Truthy() {
|
||||||
|
conn.wsHandlerFn.Release()
|
||||||
|
}
|
||||||
|
if conn.onMessageFn.Truthy() {
|
||||||
|
conn.onMessageFn.Release()
|
||||||
|
}
|
||||||
|
if conn.onCloseFn.Truthy() {
|
||||||
|
conn.onCloseFn.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
delete(p.activeConnections, conn.id)
|
||||||
|
delete(p.destinations, conn.id)
|
||||||
|
delete(p.pendingHandlers, conn.id)
|
||||||
|
p.mu.Unlock()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
func CreateJSInterface(client *Client) js.Value {
|
func CreateJSInterface(client *Client) js.Value {
|
||||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||||
|
|
||||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return js.ValueOf(false)
|
return js.ValueOf(false)
|
||||||
}
|
}
|
||||||
@@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value {
|
|||||||
|
|
||||||
_, err := client.Write(bytes)
|
_, err := client.Write(bytes)
|
||||||
return js.ValueOf(err == nil)
|
return js.ValueOf(err == nil)
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("write", writeFunc)
|
||||||
|
|
||||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return js.ValueOf(false)
|
return js.ValueOf(false)
|
||||||
}
|
}
|
||||||
@@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value {
|
|||||||
rows := args[1].Int()
|
rows := args[1].Int()
|
||||||
err := client.Resize(cols, rows)
|
err := client.Resize(cols, rows)
|
||||||
return js.ValueOf(err == nil)
|
return js.ValueOf(err == nil)
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("resize", resizeFunc)
|
||||||
|
|
||||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
client.Close()
|
client.Close()
|
||||||
return js.Undefined()
|
return js.Undefined()
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("close", closeFunc)
|
||||||
|
|
||||||
go readLoop(client, jsInterface)
|
go func() {
|
||||||
|
readLoop(client, jsInterface)
|
||||||
|
// Detach before releasing so late JS calls surface as TypeError instead
|
||||||
|
// of silent "call to released function".
|
||||||
|
jsInterface.Set("write", js.Undefined())
|
||||||
|
jsInterface.Set("resize", js.Undefined())
|
||||||
|
jsInterface.Set("close", js.Undefined())
|
||||||
|
writeFunc.Release()
|
||||||
|
resizeFunc.Release()
|
||||||
|
closeFunc.Release()
|
||||||
|
}()
|
||||||
|
|
||||||
return jsInterface
|
return jsInterface
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
|||||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||||
if servers.relaySrv != nil {
|
if servers.relaySrv != nil {
|
||||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||||
}
|
}
|
||||||
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||||
|
|
||||||
var relayAcceptFn func(conn listener.Conn)
|
var relayAcceptFn func(conn listener.Conn)
|
||||||
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
|||||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Embedded IdP (Dex)
|
||||||
|
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
|
||||||
|
idpHandler.ServeHTTP(w, r)
|
||||||
|
|
||||||
// Management HTTP API (default)
|
// Management HTTP API (default)
|
||||||
default:
|
default:
|
||||||
httpHandler.ServeHTTP(w, r)
|
httpHandler.ServeHTTP(w, r)
|
||||||
|
|||||||
@@ -53,6 +53,9 @@ type NameServerGroup struct {
|
|||||||
ID string `gorm:"primaryKey"`
|
ID string `gorm:"primaryKey"`
|
||||||
// AccountID is a reference to Account that this object belongs
|
// AccountID is a reference to Account that this object belongs
|
||||||
AccountID string `gorm:"index"`
|
AccountID string `gorm:"index"`
|
||||||
|
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||||
|
// compact wire id when sending NetworkMap components to capable peers.
|
||||||
|
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
|
||||||
// Name group name
|
// Name group name
|
||||||
Name string
|
Name string
|
||||||
// Description group description
|
// Description group description
|
||||||
|
|||||||
12
go.mod
12
go.mod
@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
|
|||||||
go 1.25.5
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cunicu.li/go-rosenpass v0.4.0
|
cunicu.li/go-rosenpass v0.5.42
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0
|
github.com/cenkalti/backoff/v4 v4.3.0
|
||||||
github.com/cloudflare/circl v1.3.3 // indirect
|
github.com/cloudflare/circl v1.3.3 // indirect
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
@@ -19,8 +19,8 @@ require (
|
|||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.50.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/sys v0.43.0
|
golang.org/x/sys v0.43.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/grpc v1.80.0
|
google.golang.org/grpc v1.80.0
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
@@ -38,7 +38,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||||
github.com/c-robinson/iplib v1.0.3
|
github.com/c-robinson/iplib v1.0.3
|
||||||
github.com/caddyserver/certmagic v0.21.3
|
github.com/caddyserver/certmagic v0.21.3
|
||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.19.0
|
||||||
github.com/coder/websocket v1.8.14
|
github.com/coder/websocket v1.8.14
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/coreos/go-oidc/v3 v3.18.0
|
github.com/coreos/go-oidc/v3 v3.18.0
|
||||||
@@ -60,7 +60,7 @@ require (
|
|||||||
github.com/google/go-cmp v0.7.0
|
github.com/google/go-cmp v0.7.0
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/google/nftables v0.3.0
|
github.com/google/nftables v0.3.0
|
||||||
github.com/gopacket/gopacket v1.1.1
|
github.com/gopacket/gopacket v1.4.0
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
|||||||
26
go.sum
26
go.sum
@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
|
|||||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
||||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||||
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
|
||||||
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
|
||||||
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||||
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||||
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
||||||
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
|
|||||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
|
||||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
||||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
|
|||||||
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
||||||
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
||||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
|
||||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
|
||||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||||
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
|
|||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||||
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
|
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
|
||||||
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
|
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
|
||||||
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
||||||
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
||||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||||
@@ -390,6 +390,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
|
|||||||
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
|
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
|
||||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
|
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
|
||||||
|
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
|
||||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
||||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||||
@@ -499,8 +501,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||||
@@ -900,8 +902,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
|||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||||
|
|||||||
@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
|||||||
if file == "" {
|
if file == "" {
|
||||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||||
}
|
}
|
||||||
return (&sql.SQLite3{File: file}).Open(logger)
|
return newSQLite3(file).Open(logger)
|
||||||
case "postgres":
|
case "postgres":
|
||||||
dsn, _ := s.Config["dsn"].(string)
|
dsn, _ := s.Config["dsn"].(string)
|
||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/dexidp/dex/server"
|
"github.com/dexidp/dex/server"
|
||||||
"github.com/dexidp/dex/server/signer"
|
"github.com/dexidp/dex/server/signer"
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/dexidp/dex/storage/sql"
|
|
||||||
jose "github.com/go-jose/go-jose/v4"
|
jose "github.com/go-jose/go-jose/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
@@ -77,7 +76,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
|||||||
|
|
||||||
// Initialize SQLite storage
|
// Initialize SQLite storage
|
||||||
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
||||||
sqliteConfig := &sql.SQLite3{File: dbPath}
|
sqliteConfig := newSQLite3(dbPath)
|
||||||
stor, err := sqliteConfig.Open(logger)
|
stor, err := sqliteConfig.Open(logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open storage: %w", err)
|
return nil, fmt.Errorf("failed to open storage: %w", err)
|
||||||
|
|||||||
15
idp/dex/sqlite_cgo.go
Normal file
15
idp/dex/sqlite_cgo.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build cgo
|
||||||
|
|
||||||
|
package dex
|
||||||
|
|
||||||
|
import (
|
||||||
|
sql "github.com/dexidp/dex/storage/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newSQLite3 builds the dex SQLite3 config. CGO builds use the upstream
|
||||||
|
// struct that takes a File path. Non-CGO builds get an empty stub whose
|
||||||
|
// Open() returns the dex "SQLite not available" error — correct behaviour
|
||||||
|
// for binaries that can't link sqlite3 (e.g. cross-compiled ARM targets).
|
||||||
|
func newSQLite3(file string) *sql.SQLite3 {
|
||||||
|
return &sql.SQLite3{File: file}
|
||||||
|
}
|
||||||
15
idp/dex/sqlite_nocgo.go
Normal file
15
idp/dex/sqlite_nocgo.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !cgo
|
||||||
|
|
||||||
|
package dex
|
||||||
|
|
||||||
|
import (
|
||||||
|
sql "github.com/dexidp/dex/storage/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newSQLite3 for non-CGO builds. The dex SQLite3 stub has no fields and its
|
||||||
|
// Open() returns an error documenting the missing CGO support — correct
|
||||||
|
// behaviour for cross-compiled artefacts that never actually run the
|
||||||
|
// embedded IdP. The `file` argument is ignored.
|
||||||
|
func newSQLite3(_ string) *sql.SQLite3 {
|
||||||
|
return &sql.SQLite3{}
|
||||||
|
}
|
||||||
@@ -55,6 +55,12 @@ type Controller struct {
|
|||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
|
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
|
|
||||||
|
// componentsDisabled, when true, forces the controller to emit legacy
|
||||||
|
// proto.NetworkMap to every peer regardless of capability. Set once at
|
||||||
|
// construction and never written after — readers race-free without a
|
||||||
|
// mutex.
|
||||||
|
componentsDisabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type bufferUpdate struct {
|
type bufferUpdate struct {
|
||||||
@@ -81,12 +87,27 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
settingsManager: settingsManager,
|
settingsManager: settingsManager,
|
||||||
dnsDomain: dnsDomain,
|
dnsDomain: dnsDomain,
|
||||||
config: config,
|
config: config,
|
||||||
|
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
|
||||||
|
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
EphemeralPeersManager: ephemeralPeersManager,
|
EphemeralPeersManager: ephemeralPeersManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerNeedsComponents reports whether the gRPC layer should emit the
|
||||||
|
// component-based wire format for this peer.
|
||||||
|
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
|
||||||
|
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
|
||||||
|
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
|
||||||
|
// literal.
|
||||||
|
func parseBoolEnv(key string) bool {
|
||||||
|
v, _ := strconv.ParseBool(os.Getenv(key))
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -112,7 +133,7 @@ func (c *Controller) CountStreams() int {
|
|||||||
return c.peersUpdateManager.CountStreams()
|
return c.peersUpdateManager.CountStreams()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -175,6 +196,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.accountManagerMetrics != nil {
|
||||||
|
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
|
||||||
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
semaphore <- struct{}{}
|
semaphore <- struct{}{}
|
||||||
go func(p *nbpeer.Peer) {
|
go func(p *nbpeer.Peer) {
|
||||||
@@ -192,18 +217,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap := proxyNetworkMaps[p.ID]
|
||||||
if ok {
|
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
result.NetworkMap.Merge(proxyNetworkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerGroups := account.GetPeerGroups(p.ID)
|
peerGroups := account.GetPeerGroups(p.ID)
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
var update *proto.SyncResponse
|
||||||
|
if result.IsComponents() {
|
||||||
|
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
|
||||||
|
// the client merges it into Calculate()'s output the same
|
||||||
|
// way the legacy server did via NetworkMap.Merge.
|
||||||
|
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
} else {
|
||||||
|
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
}
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||||
@@ -242,14 +275,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
if !b.update.Load() {
|
if !b.update.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.update.Store(false)
|
b.update.Store(false)
|
||||||
if b.next == nil {
|
if b.next == nil {
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -265,7 +298,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
|||||||
if c.accountManagerMetrics != nil {
|
if c.accountManagerMetrics != nil {
|
||||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||||
}
|
}
|
||||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||||
@@ -314,11 +347,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
result.NetworkMap.Merge(proxyNetworkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||||
@@ -329,7 +362,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
peerGroups := account.GetPeerGroups(peerId)
|
peerGroups := account.GetPeerGroups(peerId)
|
||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
var update *proto.SyncResponse
|
||||||
|
if result.IsComponents() {
|
||||||
|
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
} else {
|
||||||
|
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
}
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||||
Update: update,
|
Update: update,
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
@@ -359,14 +397,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
if !b.update.Load() {
|
if !b.update.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.update.Store(false)
|
b.update.Store(false)
|
||||||
if b.next == nil {
|
if b.next == nil {
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -376,6 +414,66 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidatedPeerWithComponents is the components-format counterpart of
|
||||||
|
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
|
||||||
|
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
|
||||||
|
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
|
||||||
|
// encodes both into the wire envelope. Callers must gate on capability
|
||||||
|
// themselves before dispatching here — this method does NOT branch on it.
|
||||||
|
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
|
if isRequiresApproval {
|
||||||
|
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
|
||||||
|
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch the proxy network map fragment for this peer alongside the
|
||||||
|
// components — same single-account-load path the streaming controller
|
||||||
|
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
|
||||||
|
// instead of waiting for the next streaming push.
|
||||||
|
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
|
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
|
||||||
|
|
||||||
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
|
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
if isRequiresApproval {
|
if isRequiresApproval {
|
||||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ type Controller interface {
|
|||||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
|
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
|
// PeerNeedsComponents combines the peer's advertised capability with the
|
||||||
|
// kill-switch flag — the only public predicate gRPC layers should ask.
|
||||||
|
PeerNeedsComponents(p *nbpeer.Peer) bool
|
||||||
GetDNSDomain(settings *types.Settings) string
|
GetDNSDomain(settings *types.Settings) string
|
||||||
StartWarmup(context.Context)
|
StartWarmup(context.Context)
|
||||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||||
|
|||||||
@@ -130,6 +130,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidatedPeerWithComponents mocks base method.
|
||||||
|
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
|
||||||
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
|
ret1, _ := ret[1].(*types.NetworkMapComponents)
|
||||||
|
ret2, _ := ret[2].(*types.NetworkMap)
|
||||||
|
ret3, _ := ret[3].([]*posture.Checks)
|
||||||
|
ret4, _ := ret[4].(int64)
|
||||||
|
ret5, _ := ret[5].(error)
|
||||||
|
return ret0, ret1, ret2, ret3, ret4, ret5
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
|
||||||
|
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerNeedsComponents mocks base method.
|
||||||
|
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
|
||||||
|
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
|
||||||
|
}
|
||||||
|
|
||||||
// OnPeerConnected mocks base method.
|
// OnPeerConnected mocks base method.
|
||||||
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
|||||||
found = true
|
found = true
|
||||||
select {
|
select {
|
||||||
case channel <- update:
|
case channel <- update:
|
||||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
|
||||||
default:
|
default:
|
||||||
dropped = true
|
dropped = true
|
||||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
)
|
)
|
||||||
@@ -47,6 +48,11 @@ type EphemeralManager struct {
|
|||||||
|
|
||||||
lifeTime time.Duration
|
lifeTime time.Duration
|
||||||
cleanupWindow time.Duration
|
cleanupWindow time.Duration
|
||||||
|
|
||||||
|
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
|
||||||
|
// no-op when the receiver is nil so deployments without an app
|
||||||
|
// metrics provider work unchanged.
|
||||||
|
metrics *telemetry.EphemeralPeersMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEphemeralManager instantiate new EphemeralManager
|
// NewEphemeralManager instantiate new EphemeralManager
|
||||||
@@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMetrics attaches a metrics collector. Safe to call once before
|
||||||
|
// LoadInitialPeers; later attachment is fine but earlier loads won't be
|
||||||
|
// reflected in the gauge. Pass nil to detach.
|
||||||
|
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
|
||||||
|
e.peersLock.Lock()
|
||||||
|
e.metrics = m
|
||||||
|
e.peersLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||||
// head.
|
// head.
|
||||||
@@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
|
|||||||
e.peersLock.Lock()
|
e.peersLock.Lock()
|
||||||
defer e.peersLock.Unlock()
|
defer e.peersLock.Unlock()
|
||||||
|
|
||||||
e.removePeer(peer.ID)
|
if e.removePeer(peer.ID) {
|
||||||
|
e.metrics.DecPending(1)
|
||||||
|
}
|
||||||
|
|
||||||
// stop the unnecessary timer
|
// stop the unnecessary timer
|
||||||
if e.headPeer == nil && e.timer != nil {
|
if e.headPeer == nil && e.timer != nil {
|
||||||
@@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
|||||||
}
|
}
|
||||||
|
|
||||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||||
|
e.metrics.IncPending()
|
||||||
if e.timer == nil {
|
if e.timer == nil {
|
||||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||||
if delay < 0 {
|
if delay < 0 {
|
||||||
@@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
|||||||
for _, p := range peers {
|
for _, p := range peers {
|
||||||
e.addPeer(p.AccountID, p.ID, t)
|
e.addPeer(p.AccountID, p.ID, t)
|
||||||
}
|
}
|
||||||
|
e.metrics.AddPending(int64(len(peers)))
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
||||||
}
|
}
|
||||||
@@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
|||||||
|
|
||||||
e.peersLock.Unlock()
|
e.peersLock.Unlock()
|
||||||
|
|
||||||
|
// Drop the gauge by the number of entries we just took off the list,
|
||||||
|
// regardless of whether the subsequent DeletePeers call succeeds. The
|
||||||
|
// list invariant is what the gauge tracks; failed delete batches are
|
||||||
|
// counted separately via CountCleanupError so we can still see them.
|
||||||
|
if len(deletePeers) > 0 {
|
||||||
|
e.metrics.CountCleanupRun()
|
||||||
|
e.metrics.DecPending(int64(len(deletePeers)))
|
||||||
|
}
|
||||||
|
|
||||||
peerIDsPerAccount := make(map[string][]string)
|
peerIDsPerAccount := make(map[string][]string)
|
||||||
for id, p := range deletePeers {
|
for id, p := range deletePeers {
|
||||||
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||||
@@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
|||||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||||
|
e.metrics.CountCleanupError()
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim
|
|||||||
e.tailPeer = ep
|
e.tailPeer = ep
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *EphemeralManager) removePeer(id string) {
|
// removePeer drops the entry from the linked list. Returns true if a
|
||||||
|
// matching entry was found and removed so callers can keep the pending
|
||||||
|
// metric gauge in sync.
|
||||||
|
func (e *EphemeralManager) removePeer(id string) bool {
|
||||||
if e.headPeer == nil {
|
if e.headPeer == nil {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.headPeer.id == id {
|
if e.headPeer.id == id {
|
||||||
@@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) {
|
|||||||
if e.tailPeer.id == id {
|
if e.tailPeer.id == id {
|
||||||
e.tailPeer = nil
|
e.tailPeer = nil
|
||||||
}
|
}
|
||||||
return
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for p := e.headPeer; p.next != nil; p = p.next {
|
for p := e.headPeer; p.next != nil; p = p.next {
|
||||||
@@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) {
|
|||||||
e.tailPeer = p
|
e.tailPeer = p
|
||||||
}
|
}
|
||||||
p.next = p.next.next
|
p.next = p.next.next
|
||||||
return
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package peers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
@@ -35,6 +36,14 @@ type Manager interface {
|
|||||||
SetAccountManager(accountManager account.Manager)
|
SetAccountManager(accountManager account.Manager)
|
||||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||||
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||||
|
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
|
||||||
|
// Returns nil with an error when no match exists. No permission check;
|
||||||
|
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
|
||||||
|
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
|
||||||
|
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
|
||||||
|
// to. Used by the proxy's auth path to authorise a request by the calling
|
||||||
|
// peer's group memberships.
|
||||||
|
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -99,6 +108,26 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
|
|||||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerByTunnelIP delegates to the store's indexed lookup.
|
||||||
|
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||||
|
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerWithGroups returns the peer plus its group memberships. Any store
|
||||||
|
// error returns (nil, nil, err) so callers never receive a valid peer
|
||||||
|
// alongside a non-nil error.
|
||||||
|
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||||
|
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return p, groups, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package peers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
|
net "net"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
account "github.com/netbirdio/netbird/management/server/account"
|
account "github.com/netbirdio/netbird/management/server/account"
|
||||||
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
types "github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockManager is a mock of Manager interface.
|
// MockManager is a mock of Manager interface.
|
||||||
@@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
|||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer mocks base method.
|
||||||
|
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||||
|
}
|
||||||
|
|
||||||
// DeletePeers mocks base method.
|
// DeletePeers mocks base method.
|
||||||
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerByTunnelIP mocks base method.
|
||||||
|
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
|
||||||
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
|
||||||
|
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeerID mocks base method.
|
// GetPeerID mocks base method.
|
||||||
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerWithGroups mocks base method.
|
||||||
|
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
|
||||||
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
|
ret1, _ := ret[1].([]*types.Group)
|
||||||
|
ret2, _ := ret[2].(error)
|
||||||
|
return ret0, ret1, ret2
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
|
||||||
|
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeersByGroupIDs mocks base method.
|
// GetPeersByGroupIDs mocks base method.
|
||||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProxyPeer mocks base method.
|
|
||||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
|
||||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ type Domain struct {
|
|||||||
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
||||||
// Not persisted.
|
// Not persisted.
|
||||||
SupportsCrowdSec *bool `gorm:"-"`
|
SupportsCrowdSec *bool `gorm:"-"`
|
||||||
|
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
|
||||||
|
SupportsPrivate *bool `gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventMeta returns activity event metadata for a domain
|
// EventMeta returns activity event metadata for a domain
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
|||||||
SupportsCustomPorts: d.SupportsCustomPorts,
|
SupportsCustomPorts: d.SupportsCustomPorts,
|
||||||
RequireSubdomain: d.RequireSubdomain,
|
RequireSubdomain: d.RequireSubdomain,
|
||||||
SupportsCrowdsec: d.SupportsCrowdSec,
|
SupportsCrowdsec: d.SupportsCrowdSec,
|
||||||
|
SupportsPrivate: d.SupportsPrivate,
|
||||||
}
|
}
|
||||||
if d.TargetCluster != "" {
|
if d.TargetCluster != "" {
|
||||||
resp.TargetCluster = &d.TargetCluster
|
resp.TargetCluster = &d.TargetCluster
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type proxyManager interface {
|
|||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
|||||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||||
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
||||||
|
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
|
||||||
ret = append(ret, d)
|
ret = append(ret, d)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
|||||||
if d.TargetCluster != "" {
|
if d.TargetCluster != "" {
|
||||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||||
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
||||||
|
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
|
||||||
}
|
}
|
||||||
// Custom domains never require a subdomain by default since
|
// Custom domains never require a subdomain by default since
|
||||||
// the account owns them and should be able to use the bare domain.
|
// the account owns them and should be able to use the bare domain.
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type mockProxyManager struct {
|
type mockProxyManager struct {
|
||||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||||
pm := &mockProxyManager{
|
pm := &mockProxyManager{
|
||||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||||
@@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type Manager interface {
|
|||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ type store interface {
|
|||||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||||
@@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
|||||||
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
|
||||||
|
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||||
|
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||||
@@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,16 +15,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type mockStore struct {
|
type mockStore struct {
|
||||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||||
@@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
|
|||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||||
@@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
|
|||||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newTestManager(s store) *Manager {
|
func newTestManager(s store) *Manager {
|
||||||
meter := noop.NewMeterProvider().Meter("test")
|
meter := noop.NewMeterProvider().Meter("test")
|
||||||
|
|||||||
@@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClusterSupportsPrivate mocks base method.
|
||||||
|
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
|
||||||
|
ret0, _ := ret[0].(*bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
|
||||||
|
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// Connect mocks base method.
|
// Connect mocks base method.
|
||||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ type Capabilities struct {
|
|||||||
RequireSubdomain *bool
|
RequireSubdomain *bool
|
||||||
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
||||||
SupportsCrowdsec *bool
|
SupportsCrowdsec *bool
|
||||||
|
// Private indicates whether this proxy supports inbound access via Wireguard
|
||||||
|
// tunnel and netbird-only authentication policies
|
||||||
|
Private *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy represents a reverse proxy instance
|
// Proxy represents a reverse proxy instance
|
||||||
@@ -42,10 +45,34 @@ func (Proxy) TableName() string {
|
|||||||
return "proxies"
|
return "proxies"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClusterType is the source of a proxy cluster.
|
||||||
|
type ClusterType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ClusterTypeAccount is a cluster operated by the account itself (BYOP) —
|
||||||
|
// at least one proxy row in the cluster carries a non-NULL account_id.
|
||||||
|
ClusterTypeAccount ClusterType = "account"
|
||||||
|
// ClusterTypeShared is a cluster operated by NetBird and shared across
|
||||||
|
// accounts — all proxy rows in the cluster have account_id IS NULL.
|
||||||
|
ClusterTypeShared ClusterType = "shared"
|
||||||
|
)
|
||||||
|
|
||||||
// Cluster represents a group of proxy nodes serving the same address.
|
// Cluster represents a group of proxy nodes serving the same address.
|
||||||
|
//
|
||||||
|
// Online and ConnectedProxies derive from the same 2-min active window
|
||||||
|
// the rest of the module uses, but Cluster rows are not gated on it —
|
||||||
|
// the cluster listing surfaces offline clusters too so operators can
|
||||||
|
// see and clean them up. The 1-hour heartbeat reaper still bounds the
|
||||||
|
// table eventually.
|
||||||
type Cluster struct {
|
type Cluster struct {
|
||||||
ID string
|
ID string
|
||||||
Address string
|
Address string
|
||||||
|
Type ClusterType
|
||||||
|
Online bool
|
||||||
ConnectedProxies int
|
ConnectedProxies int
|
||||||
SelfHosted bool
|
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
|
||||||
|
SupportsCustomPorts *bool
|
||||||
|
RequireSubdomain *bool
|
||||||
|
SupportsCrowdSec *bool
|
||||||
|
Private *bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||||
|
|||||||
@@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID,
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAllServices mocks base method.
|
|
||||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
|
||||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteAccountCluster mocks base method.
|
// DeleteAccountCluster mocks base method.
|
||||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID,
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteAllServices mocks base method.
|
||||||
|
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||||
|
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteService mocks base method.
|
// DeleteService mocks base method.
|
||||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -122,21 +122,6 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusters mocks base method.
|
|
||||||
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
|
|
||||||
ret0, _ := ret[0].([]proxy.Cluster)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
|
||||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllServices mocks base method.
|
// GetAllServices mocks base method.
|
||||||
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -152,19 +137,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServiceByDomain mocks base method.
|
// GetClusters mocks base method.
|
||||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID)
|
||||||
ret0, _ := ret[0].(*Service)
|
ret0, _ := ret[0].([]proxy.Cluster)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
// GetClusters indicates an expected call of GetClusters.
|
||||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGlobalServices mocks base method.
|
// GetGlobalServices mocks base method.
|
||||||
@@ -197,6 +182,21 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetServiceByDomain mocks base method.
|
||||||
|
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||||
|
ret0, _ := ret[0].(*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||||
|
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
// GetServiceByID mocks base method.
|
// GetServiceByID mocks base method.
|
||||||
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -196,10 +196,15 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||||
for _, c := range clusters {
|
for _, c := range clusters {
|
||||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||||
Id: c.ID,
|
Id: c.ID,
|
||||||
Address: c.Address,
|
Address: c.Address,
|
||||||
ConnectedProxies: c.ConnectedProxies,
|
Type: api.ProxyClusterType(c.Type),
|
||||||
SelfHosted: c.SelfHosted,
|
Online: c.Online,
|
||||||
|
ConnectedProxies: c.ConnectedProxies,
|
||||||
|
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||||
|
RequireSubdomain: c.RequireSubdomain,
|
||||||
|
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||||
|
Private: c.Private,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,8 @@ type ClusterDeriver interface {
|
|||||||
type CapabilityProvider interface {
|
type CapabilityProvider interface {
|
||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@@ -112,8 +114,12 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
|||||||
m.exposeReaper.StartExposeReaper(ctx)
|
m.exposeReaper.StartExposeReaper(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
// GetClusters returns every proxy cluster visible to the account
|
||||||
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
// (shared + its own BYOP), regardless of whether any proxy in the
|
||||||
|
// cluster is currently heartbeating. Each cluster is enriched with the
|
||||||
|
// capability flags reported by its active proxies so the dashboard can
|
||||||
|
// render feature support without a second round-trip.
|
||||||
|
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.NewPermissionValidationError(err)
|
return nil, status.NewPermissionValidationError(err)
|
||||||
@@ -122,7 +128,19 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
clusters, err := m.store.GetProxyClusters(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range clusters {
|
||||||
|
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||||
|
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||||
|
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||||
|
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clusters, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||||
@@ -192,6 +210,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
|||||||
target.Host = resource.Domain
|
target.Host = resource.Domain
|
||||||
case service.TargetTypeSubnet:
|
case service.TargetTypeSubnet:
|
||||||
// For subnets we do not do any lookups on the resource
|
// For subnets we do not do any lookups on the resource
|
||||||
|
case service.TargetTypeCluster:
|
||||||
|
// Cluster targets carry the upstream address on target_id; the
|
||||||
|
// proxy resolves the destination at request time.
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||||
}
|
}
|
||||||
@@ -763,6 +784,10 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
|||||||
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
case service.TargetTypeCluster:
|
||||||
|
if err := validateClusterTarget(target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
||||||
}
|
}
|
||||||
@@ -770,6 +795,13 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateClusterTarget(target *service.Target) error {
|
||||||
|
if !target.Options.DirectUpstream {
|
||||||
|
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
|
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
|
||||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
@@ -946,12 +978,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
|
|||||||
return fmt.Errorf("failed to get services: %w", err)
|
return fmt.Errorf("failed to get services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||||
|
|
||||||
for _, s := range services {
|
for _, s := range services {
|
||||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||||
}
|
}
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1344,3 +1344,66 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
accountID := "test-account"
|
||||||
|
|
||||||
|
// No peer or resource lookups must be issued for cluster targets.
|
||||||
|
targets := []*rpservice.Target{
|
||||||
|
{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: rpservice.TargetTypeCluster,
|
||||||
|
Options: rpservice.TargetOptions{DirectUpstream: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
|
||||||
|
// store-side check that cluster targets must opt into the host-stack dial
|
||||||
|
// path. Without DirectUpstream the proxy would route this target through
|
||||||
|
// the embedded NetBird client and fail on every request.
|
||||||
|
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
accountID := "test-account"
|
||||||
|
|
||||||
|
targets := []*rpservice.Target{
|
||||||
|
{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: rpservice.TargetTypeCluster,
|
||||||
|
Host: "backend.lan",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := validateTargetReferences(ctx, mockStore, accountID, targets)
|
||||||
|
require.Error(t, err, "cluster target without direct_upstream must be rejected")
|
||||||
|
assert.ErrorContains(t, err, "direct upstream disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
accountID := "test-account"
|
||||||
|
|
||||||
|
mgr := &Manager{store: mockStore}
|
||||||
|
|
||||||
|
svc := &rpservice.Service{
|
||||||
|
ID: "svc-1",
|
||||||
|
AccountID: accountID,
|
||||||
|
Targets: []*rpservice.Target{
|
||||||
|
{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: rpservice.TargetTypeCluster,
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
|
||||||
|
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,10 +45,11 @@ const (
|
|||||||
StatusCertificateFailed Status = "certificate_failed"
|
StatusCertificateFailed Status = "certificate_failed"
|
||||||
StatusError Status = "error"
|
StatusError Status = "error"
|
||||||
|
|
||||||
TargetTypePeer TargetType = "peer"
|
TargetTypePeer TargetType = "peer"
|
||||||
TargetTypeHost TargetType = "host"
|
TargetTypeHost TargetType = "host"
|
||||||
TargetTypeDomain TargetType = "domain"
|
TargetTypeDomain TargetType = "domain"
|
||||||
TargetTypeSubnet TargetType = "subnet"
|
TargetTypeSubnet TargetType = "subnet"
|
||||||
|
TargetTypeCluster TargetType = "cluster"
|
||||||
|
|
||||||
SourcePermanent = "permanent"
|
SourcePermanent = "permanent"
|
||||||
SourceEphemeral = "ephemeral"
|
SourceEphemeral = "ephemeral"
|
||||||
@@ -60,6 +61,11 @@ type TargetOptions struct {
|
|||||||
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
||||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||||
|
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
|
||||||
|
// the target via the proxy host's network stack. Useful for upstreams
|
||||||
|
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||||
|
// sidecars). Default false.
|
||||||
|
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Target struct {
|
type Target struct {
|
||||||
@@ -67,7 +73,7 @@ type Target struct {
|
|||||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||||
Path *string `json:"path,omitempty"`
|
Path *string `json:"path,omitempty"`
|
||||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
Host string `json:"host"`
|
||||||
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
||||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||||
@@ -200,6 +206,10 @@ type Service struct {
|
|||||||
Mode string `gorm:"default:'http'"`
|
Mode string `gorm:"default:'http'"`
|
||||||
ListenPort uint16
|
ListenPort uint16
|
||||||
PortAutoAssigned bool
|
PortAutoAssigned bool
|
||||||
|
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
|
||||||
|
Private bool
|
||||||
|
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
|
||||||
|
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||||
@@ -299,6 +309,12 @@ func (s *Service) ToAPIResponse() *api.Service {
|
|||||||
Mode: &mode,
|
Mode: &mode,
|
||||||
ListenPort: &listenPort,
|
ListenPort: &listenPort,
|
||||||
PortAutoAssigned: &s.PortAutoAssigned,
|
PortAutoAssigned: &s.PortAutoAssigned,
|
||||||
|
Private: &s.Private,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.AccessGroups) > 0 {
|
||||||
|
groups := append([]string(nil), s.AccessGroups...)
|
||||||
|
resp.AccessGroups = &groups
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ProxyCluster != "" {
|
if s.ProxyCluster != "" {
|
||||||
@@ -308,6 +324,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToProtoMapping converts the service into the wire format the proxy consumes.
|
||||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||||
pathMappings := s.buildPathMappings()
|
pathMappings := s.buildPathMappings()
|
||||||
|
|
||||||
@@ -349,6 +366,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
|||||||
RewriteRedirects: s.RewriteRedirects,
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
Mode: s.Mode,
|
Mode: s.Mode,
|
||||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||||
|
Private: s.Private,
|
||||||
}
|
}
|
||||||
|
|
||||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||||
@@ -455,7 +473,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
|
||||||
|
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
apiOpts := &api.ServiceTargetOptions{}
|
apiOpts := &api.ServiceTargetOptions{}
|
||||||
@@ -477,17 +496,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
|||||||
if len(opts.CustomHeaders) > 0 {
|
if len(opts.CustomHeaders) > 0 {
|
||||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||||
}
|
}
|
||||||
|
if opts.DirectUpstream {
|
||||||
|
apiOpts.DirectUpstream = &opts.DirectUpstream
|
||||||
|
}
|
||||||
return apiOpts
|
return apiOpts
|
||||||
}
|
}
|
||||||
|
|
||||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||||
|
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
popts := &proto.PathTargetOptions{
|
popts := &proto.PathTargetOptions{
|
||||||
SkipTlsVerify: opts.SkipTLSVerify,
|
SkipTlsVerify: opts.SkipTLSVerify,
|
||||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||||
CustomHeaders: opts.CustomHeaders,
|
CustomHeaders: opts.CustomHeaders,
|
||||||
|
DirectUpstream: opts.DirectUpstream,
|
||||||
}
|
}
|
||||||
if opts.RequestTimeout != 0 {
|
if opts.RequestTimeout != 0 {
|
||||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||||
@@ -537,6 +561,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
|
|||||||
if o.CustomHeaders != nil {
|
if o.CustomHeaders != nil {
|
||||||
opts.CustomHeaders = *o.CustomHeaders
|
opts.CustomHeaders = *o.CustomHeaders
|
||||||
}
|
}
|
||||||
|
if o.DirectUpstream != nil {
|
||||||
|
opts.DirectUpstream = *o.DirectUpstream
|
||||||
|
}
|
||||||
return opts, nil
|
return opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -551,6 +578,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
|||||||
if req.ListenPort != nil {
|
if req.ListenPort != nil {
|
||||||
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
||||||
}
|
}
|
||||||
|
if req.Private != nil {
|
||||||
|
s.Private = *req.Private
|
||||||
|
}
|
||||||
|
if req.AccessGroups != nil {
|
||||||
|
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
|
||||||
|
} else {
|
||||||
|
s.AccessGroups = nil
|
||||||
|
}
|
||||||
|
|
||||||
targets, err := targetsFromAPI(accountID, req.Targets)
|
targets, err := targetsFromAPI(accountID, req.Targets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -740,6 +775,9 @@ func (s *Service) Validate() error {
|
|||||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := s.validatePrivateRequirements(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
switch s.Mode {
|
switch s.Mode {
|
||||||
case ModeHTTP:
|
case ModeHTTP:
|
||||||
@@ -753,6 +791,23 @@ func (s *Service) Validate() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
|
||||||
|
func (s *Service) validatePrivateRequirements() error {
|
||||||
|
if !s.Private {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.Mode != "" && s.Mode != ModeHTTP {
|
||||||
|
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
|
||||||
|
}
|
||||||
|
if len(s.AccessGroups) == 0 {
|
||||||
|
return errors.New("private services require at least one access group")
|
||||||
|
}
|
||||||
|
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||||
|
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) validateHTTPMode() error {
|
func (s *Service) validateHTTPMode() error {
|
||||||
if s.Domain == "" {
|
if s.Domain == "" {
|
||||||
return errors.New("service domain is required")
|
return errors.New("service domain is required")
|
||||||
@@ -799,11 +854,21 @@ func (s *Service) validateHTTPTargets() error {
|
|||||||
for i, target := range s.Targets {
|
for i, target := range s.Targets {
|
||||||
switch target.TargetType {
|
switch target.TargetType {
|
||||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||||
// host field will be ignored
|
// Host is normally overwritten by replaceHostByLookup with the
|
||||||
|
// resolved peer IP / resource address; operator-supplied values
|
||||||
|
// are honored only when DirectUpstream is set. Validate the
|
||||||
|
// override here so misconfigured hosts fail fast at API time.
|
||||||
|
if err := validateDirectUpstreamHost(i, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case TargetTypeSubnet:
|
case TargetTypeSubnet:
|
||||||
if target.Host == "" {
|
if target.Host == "" {
|
||||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||||
}
|
}
|
||||||
|
case TargetTypeCluster:
|
||||||
|
if err := validateClusterTarget(i, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||||
}
|
}
|
||||||
@@ -821,25 +886,67 @@ func (s *Service) validateHTTPTargets() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
|
||||||
|
func validateClusterTarget(idx int, target *Target) error {
|
||||||
|
host := strings.TrimSpace(target.Host)
|
||||||
|
if host == "" {
|
||||||
|
return fmt.Errorf("target %d: has empty host", idx)
|
||||||
|
}
|
||||||
|
if !target.Options.DirectUpstream {
|
||||||
|
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
|
||||||
|
}
|
||||||
|
return validateDirectUpstreamHost(idx, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDirectUpstreamHost validates the operator-supplied Host on a
|
||||||
|
// peer/host/domain target when DirectUpstream is set. Empty Host is
|
||||||
|
// allowed — the lookup fills in the default peer IP / resource address.
|
||||||
|
// Without DirectUpstream the Host value is silently overwritten by
|
||||||
|
// replaceHostByLookup, so we don't validate it (preserves the historical
|
||||||
|
// behaviour where APIs accepted any value and dropped it). Non-empty
|
||||||
|
// Host with DirectUpstream must look like a hostname or IP and must
|
||||||
|
// not carry a port (port lives on Target.Port).
|
||||||
|
func validateDirectUpstreamHost(idx int, target *Target) error {
|
||||||
|
if !target.Options.DirectUpstream {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
host := strings.TrimSpace(target.Host)
|
||||||
|
if host == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(host, " \t/") {
|
||||||
|
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
|
||||||
|
}
|
||||||
|
if _, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) validateL4Target(target *Target) error {
|
func (s *Service) validateL4Target(target *Target) error {
|
||||||
// L4 services have a single target; per-target disable is meaningless
|
// L4 services have a single target; per-target disable is meaningless
|
||||||
// (use the service-level Enabled flag instead). Force it on so that
|
// (use the service-level Enabled flag instead). Force it on so that
|
||||||
// buildPathMappings always includes the target in the proto.
|
// buildPathMappings always includes the target in the proto.
|
||||||
target.Enabled = true
|
target.Enabled = true
|
||||||
|
|
||||||
if target.Port == 0 {
|
|
||||||
return errors.New("target port is required for L4 services")
|
|
||||||
}
|
|
||||||
if target.TargetId == "" {
|
if target.TargetId == "" {
|
||||||
return errors.New("target_id is required for L4 services")
|
return errors.New("target_id is required for L4 services")
|
||||||
}
|
}
|
||||||
|
if target.TargetType != TargetTypeCluster && target.Port == 0 {
|
||||||
|
return errors.New("target port is required for L4 services")
|
||||||
|
}
|
||||||
switch target.TargetType {
|
switch target.TargetType {
|
||||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||||
// OK
|
if err := validateDirectUpstreamHost(0, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case TargetTypeSubnet:
|
case TargetTypeSubnet:
|
||||||
if target.Host == "" {
|
if target.Host == "" {
|
||||||
return errors.New("target host is required for subnet targets")
|
return errors.New("target host is required for subnet targets")
|
||||||
}
|
}
|
||||||
|
case TargetTypeCluster:
|
||||||
|
// target_id carries the cluster address; the proxy resolves
|
||||||
|
// the upstream at request time.
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
||||||
}
|
}
|
||||||
@@ -1174,6 +1281,11 @@ func (s *Service) Copy() *Service {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var accessGroups []string
|
||||||
|
if len(s.AccessGroups) > 0 {
|
||||||
|
accessGroups = append([]string(nil), s.AccessGroups...)
|
||||||
|
}
|
||||||
|
|
||||||
return &Service{
|
return &Service{
|
||||||
ID: s.ID,
|
ID: s.ID,
|
||||||
AccountID: s.AccountID,
|
AccountID: s.AccountID,
|
||||||
@@ -1195,6 +1307,8 @@ func (s *Service) Copy() *Service {
|
|||||||
Mode: s.Mode,
|
Mode: s.Mode,
|
||||||
ListenPort: s.ListenPort,
|
ListenPort: s.ListenPort,
|
||||||
PortAutoAssigned: s.PortAutoAssigned,
|
PortAutoAssigned: s.PortAutoAssigned,
|
||||||
|
Private: s.Private,
|
||||||
|
AccessGroups: accessGroups,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1116,3 +1117,191 @@ func TestValidate_HeaderAuths(t *testing.T) {
|
|||||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidate_HTTPClusterTarget(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
|
||||||
|
// rule that operator-supplied Host is mandatory: cluster targets dial the
|
||||||
|
// upstream via the host network stack (direct_upstream is implied), so an
|
||||||
|
// empty Host leaves the proxy with nothing to dial.
|
||||||
|
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
|
||||||
|
// half of the cluster-target rule: DirectUpstream must be true so the
|
||||||
|
// stdlib transport branch in MultiTransport is taken. Without it the
|
||||||
|
// embedded NetBird client would try to dial the cluster address through
|
||||||
|
// the WG tunnel, which is the wrong network for a cluster upstream.
|
||||||
|
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Mode = ModeTCP
|
||||||
|
rp.ListenPort = 9000
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "tcp",
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
|
||||||
|
svc := validProxy()
|
||||||
|
svc.Private = true
|
||||||
|
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
|
||||||
|
cp := svc.Copy()
|
||||||
|
require.NotNil(t, cp)
|
||||||
|
assert.True(t, cp.Private)
|
||||||
|
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
|
||||||
|
|
||||||
|
cp.Private = false
|
||||||
|
assert.True(t, svc.Private)
|
||||||
|
|
||||||
|
cp.AccessGroups[0] = "grp-other"
|
||||||
|
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_APIRoundtrip_Private(t *testing.T) {
|
||||||
|
enabled := true
|
||||||
|
private := true
|
||||||
|
accessGroups := []string{"grp-admins"}
|
||||||
|
targets := []api.ServiceTarget{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: api.ServiceTargetTargetType("cluster"),
|
||||||
|
Protocol: "http",
|
||||||
|
Port: 80,
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
req := &api.ServiceRequest{
|
||||||
|
Name: "svc-private",
|
||||||
|
Domain: "myapp.eu.proxy.netbird.io",
|
||||||
|
Enabled: enabled,
|
||||||
|
Private: &private,
|
||||||
|
AccessGroups: &accessGroups,
|
||||||
|
Targets: &targets,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &Service{}
|
||||||
|
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
|
||||||
|
assert.True(t, svc.Private)
|
||||||
|
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
|
||||||
|
|
||||||
|
resp := svc.ToAPIResponse()
|
||||||
|
require.NotNil(t, resp.Private)
|
||||||
|
assert.True(t, *resp.Private)
|
||||||
|
require.NotNil(t, resp.AccessGroups)
|
||||||
|
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "access group")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
rp.Auth.BearerAuth = &BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"grp-sso"},
|
||||||
|
}
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
require.NoError(t, rp.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
require.NoError(t, rp.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
rp.Mode = ModeTCP
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "tcp",
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "HTTP")
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,20 @@ type KeyPair struct {
|
|||||||
type Claims struct {
|
type Claims struct {
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
Method auth.Method `json:"method"`
|
Method auth.Method `json:"method"`
|
||||||
|
// Email is the calling user's email address. Carried so the
|
||||||
|
// proxy can stamp identity on upstream requests (e.g.
|
||||||
|
// x-litellm-end-user-id) without an extra management
|
||||||
|
// round-trip on every cookie-bearing request.
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
// Groups carries the user's group IDs so the proxy can stamp them
|
||||||
|
// onto upstream requests (X-NetBird-Groups) from the cookie path
|
||||||
|
// without an extra management round-trip.
|
||||||
|
Groups []string `json:"groups,omitempty"`
|
||||||
|
// GroupNames carries the human-readable display names for the ids
|
||||||
|
// in Groups, ordered identically (positional pairing). Slice may be
|
||||||
|
// shorter than Groups for tokens minted before names were
|
||||||
|
// resolvable; the consumer falls back to ids for missing positions.
|
||||||
|
GroupNames []string `json:"group_names,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateKeyPair() (*KeyPair, error) {
|
func GenerateKeyPair() (*KeyPair, error) {
|
||||||
@@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
// SignToken mints a session JWT for the given user and domain. email,
|
||||||
|
// groups, and groupNames, when non-empty, are embedded so the proxy can
|
||||||
|
// authorise and stamp identity for policy-aware middlewares without a
|
||||||
|
// management round-trip on every cookie-bearing request. groupNames
|
||||||
|
// pairs positionally with groups; pass nil when names couldn't be
|
||||||
|
// resolved.
|
||||||
|
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
|
||||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("decode private key: %w", err)
|
return "", fmt.Errorf("decode private key: %w", err)
|
||||||
@@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration
|
|||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
},
|
},
|
||||||
Method: method,
|
Method: method,
|
||||||
|
Email: email,
|
||||||
|
Groups: append([]string(nil), groups...),
|
||||||
|
GroupNames: append([]string(nil), groupNames...),
|
||||||
}
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
"github.com/rs/cors"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -19,7 +21,6 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
@@ -27,16 +28,20 @@ import (
|
|||||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util/crypt"
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const apiPrefix = "/api"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
kaep = keepalive.EnforcementPolicy{
|
kaep = keepalive.EnforcementPolicy{
|
||||||
MinTime: 15 * time.Second,
|
MinTime: 15 * time.Second,
|
||||||
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) EventStore() activity.Store {
|
func (s *BaseServer) EventStore() activity.Store {
|
||||||
return Create(s, func() activity.Store {
|
return Create(s, func() activity.Store {
|
||||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
var err error
|
||||||
if err != nil {
|
key := s.Config.DataStoreEncryptionKey
|
||||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
if key == "" {
|
||||||
|
log.Debugf("generate new activity store encryption key")
|
||||||
|
key, err = crypt.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to generate event store encryption key: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize event store: %v", err)
|
log.Fatalf("failed to initialize event store: %v", err)
|
||||||
}
|
}
|
||||||
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
|
||||||
|
// the deployment isn't using the embedded variant.
|
||||||
|
func (s *BaseServer) IDPHandler() http.Handler {
|
||||||
|
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
|
||||||
|
if !ok || embeddedIdP == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cors.AllowAll().Handler(embeddedIdP.Handler())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) Router() *mux.Router {
|
||||||
|
return Create(s, func() *mux.Router {
|
||||||
|
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||||
return Create(s, func() *middleware.APIRateLimiter {
|
return Create(s, func() *middleware.APIRateLimiter {
|
||||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
@@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
integratedPeerValidator, err := validator.NewIntegratedValidator(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
s.PeersManager(),
|
s.PeersManager(),
|
||||||
s.SettingsManager(),
|
s.SettingsManager(),
|
||||||
@@ -112,7 +113,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||||
return Create(s, func() ephemeral.Manager {
|
return Create(s, func() ephemeral.Manager {
|
||||||
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||||
|
if metrics := s.Metrics(); metrics != nil {
|
||||||
|
em.SetMetrics(metrics.EphemeralPeersMetrics())
|
||||||
|
}
|
||||||
|
return em
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
|||||||
|
|
||||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||||
return Create(s, func() permissions.Manager {
|
return Create(s, func() permissions.Manager {
|
||||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
return permissions.NewManager(s.Store())
|
||||||
|
|
||||||
s.AfterInit(func(s *BaseServer) {
|
|
||||||
manager.SetAccountManager(s.AccountManager())
|
|
||||||
})
|
|
||||||
|
|
||||||
return manager
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
|||||||
return idpManager
|
return idpManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
|||||||
return &m
|
return &m
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
|
||||||
switch {
|
switch {
|
||||||
case s.certManager != nil:
|
case s.certManager != nil:
|
||||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||||
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
|||||||
log.Tracef("custom handler set successfully")
|
log.Tracef("custom handler set successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||||
// Check if a custom handler was set (for multiplexing additional services)
|
// Check if a custom handler was set (for multiplexing additional services)
|
||||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||||
if handler, ok := customHandler.(http.Handler); ok {
|
if handler, ok := customHandler.(http.Handler); ok {
|
||||||
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
|
|||||||
gRPCHandler.ServeHTTP(writer, request)
|
gRPCHandler.ServeHTTP(writer, request)
|
||||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||||
wsProxy.Handler().ServeHTTP(writer, request)
|
wsProxy.Handler().ServeHTTP(writer, request)
|
||||||
|
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
|
||||||
|
idpHandler.ServeHTTP(writer, request)
|
||||||
default:
|
default:
|
||||||
httpHandler.ServeHTTP(writer, request)
|
httpHandler.ServeHTTP(writer, request)
|
||||||
}
|
}
|
||||||
|
|||||||
813
management/internals/shared/grpc/components_encoder.go
Normal file
813
management/internals/shared/grpc/components_encoder.go
Normal file
@@ -0,0 +1,813 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// wgKeyRawLen is the raw byte length of a WireGuard public key.
|
||||||
|
const wgKeyRawLen = 32
|
||||||
|
|
||||||
|
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
|
||||||
|
// The envelope is fully self-contained — every field needed by the client's
|
||||||
|
// local Calculate() comes from the components struct itself. The only
|
||||||
|
// externally-supplied data is the receiving peer's PeerConfig (which is
|
||||||
|
// computed alongside the components in the network_map controller and reused
|
||||||
|
// from the legacy proto path) and the dns_domain string.
|
||||||
|
type ComponentsEnvelopeInput struct {
|
||||||
|
Components *types.NetworkMapComponents
|
||||||
|
PeerConfig *proto.PeerConfig
|
||||||
|
DNSDomain string
|
||||||
|
DNSForwarderPort int64
|
||||||
|
// UserIDClaim is the OIDC claim name the client should embed in
|
||||||
|
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
|
||||||
|
// is OK — client treats empty as "no SshAuth to build".
|
||||||
|
UserIDClaim string
|
||||||
|
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
|
||||||
|
// external controllers (BYOP/port-forwarding). Nil when no proxy data
|
||||||
|
// is present; encoder skips the field in that case.
|
||||||
|
ProxyPatch *proto.ProxyPatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
|
||||||
|
// wire envelope. The encoder is intentionally non-deterministic: it iterates
|
||||||
|
// Go maps in their native (random) order. Indexes inside the envelope
|
||||||
|
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
|
||||||
|
// are self-consistent within a single encode, so the decoder reconstructs
|
||||||
|
// the same typed objects regardless of emit order. Tests that need to
|
||||||
|
// compare envelopes do so semantically via proto round-trip + canonicalize,
|
||||||
|
// not byte-equal.
|
||||||
|
//
|
||||||
|
// Callers must NOT concatenate or merge envelopes from different encodes —
|
||||||
|
// index spaces are local to a single envelope.
|
||||||
|
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
|
||||||
|
c := in.Components
|
||||||
|
|
||||||
|
// Graceful degrade when components is nil — matches the legacy path's
|
||||||
|
// behaviour for missing/unvalidated peers (return a NetworkMap with only
|
||||||
|
// Network populated). The receiver gets an envelope it can decode
|
||||||
|
// without crashing; AccountSettings stays non-nil so client-side
|
||||||
|
// dereferences are safe.
|
||||||
|
if c == nil {
|
||||||
|
// Match legacy missing-peer minimum: a NetworkMap with only Network
|
||||||
|
// populated. The receiver gets enough to bootstrap (Network
|
||||||
|
// identifier, dns_domain, account_settings) and nothing else.
|
||||||
|
return &proto.NetworkMapEnvelope{
|
||||||
|
Payload: &proto.NetworkMapEnvelope_Full{
|
||||||
|
Full: &proto.NetworkMapComponentsFull{
|
||||||
|
PeerConfig: in.PeerConfig,
|
||||||
|
DnsDomain: in.DNSDomain,
|
||||||
|
DnsForwarderPort: in.DNSForwarderPort,
|
||||||
|
UserIdClaim: in.UserIDClaim,
|
||||||
|
AccountSettings: &proto.AccountSettingsCompact{},
|
||||||
|
ProxyPatch: in.ProxyPatch,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
|
||||||
|
// every regular peer (in c.Peers) must be indexed before any encoder
|
||||||
|
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
|
||||||
|
// peers that exist only in c.RouterPeers would silently lose their
|
||||||
|
// peer_index reference.
|
||||||
|
enc := newComponentEncoder(c)
|
||||||
|
enc.indexAllPeers()
|
||||||
|
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
|
||||||
|
|
||||||
|
// Phase 2: gather every policy that any consumer references (peer-pair
|
||||||
|
// policies + resource-only policies) so encodeResourcePoliciesMap can
|
||||||
|
// translate every *Policy pointer to a wire index.
|
||||||
|
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
|
||||||
|
policies, policyToIdxs := enc.encodePolicies(allPolicies)
|
||||||
|
|
||||||
|
// Phase 3: emit. Order of struct field expressions no longer matters:
|
||||||
|
// every encoder either reads from the dedup tables or works on
|
||||||
|
// independent input.
|
||||||
|
full := &proto.NetworkMapComponentsFull{
|
||||||
|
Serial: networkSerial(c.Network),
|
||||||
|
PeerConfig: in.PeerConfig,
|
||||||
|
Network: toAccountNetwork(c.Network),
|
||||||
|
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
|
||||||
|
DnsForwarderPort: in.DNSForwarderPort,
|
||||||
|
UserIdClaim: in.UserIDClaim,
|
||||||
|
ProxyPatch: in.ProxyPatch,
|
||||||
|
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
|
||||||
|
DnsDomain: in.DNSDomain,
|
||||||
|
CustomZoneDomain: c.CustomZoneDomain,
|
||||||
|
AgentVersions: enc.agentVersions,
|
||||||
|
Peers: enc.peers,
|
||||||
|
RouterPeerIndexes: routerIdxs,
|
||||||
|
Policies: policies,
|
||||||
|
Groups: enc.encodeGroups(),
|
||||||
|
Routes: enc.encodeRoutes(c.Routes),
|
||||||
|
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
|
||||||
|
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
|
||||||
|
AccountZones: encodeCustomZones(c.AccountZones),
|
||||||
|
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
|
||||||
|
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
|
||||||
|
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
|
||||||
|
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
|
||||||
|
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
|
||||||
|
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.NetworkMapEnvelope{
|
||||||
|
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
|
||||||
|
// production path always populates c.Network, but the encoder is exported
|
||||||
|
// and a hand-built components struct may omit it.
|
||||||
|
func networkSerial(n *types.Network) uint64 {
|
||||||
|
if n == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return n.CurrentSerial()
|
||||||
|
}
|
||||||
|
|
||||||
|
type componentEncoder struct {
|
||||||
|
components *types.NetworkMapComponents
|
||||||
|
|
||||||
|
peerOrder map[string]uint32
|
||||||
|
peers []*proto.PeerCompact
|
||||||
|
|
||||||
|
agentVersionOrder map[string]uint32
|
||||||
|
agentVersions []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
|
||||||
|
return &componentEncoder{
|
||||||
|
components: c,
|
||||||
|
peerOrder: make(map[string]uint32, len(c.Peers)),
|
||||||
|
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
|
||||||
|
agentVersionOrder: make(map[string]uint32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) indexAllPeers() {
|
||||||
|
for _, p := range e.components.Peers {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
e.appendPeer(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
|
||||||
|
if idx, ok := e.peerOrder[p.ID]; ok {
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
idx := uint32(len(e.peers))
|
||||||
|
e.peerOrder[p.ID] = idx
|
||||||
|
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
|
||||||
|
if idx, ok := e.agentVersionOrder[v]; ok {
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
// Lazy-initialise the table with "" at index 0 so the empty string
|
||||||
|
// stays interchangeable with proto3's default uint32=0 — peers without
|
||||||
|
// a WtVersion don't force the table to materialise.
|
||||||
|
if v == "" {
|
||||||
|
idx := uint32(len(e.agentVersions))
|
||||||
|
if idx == 0 {
|
||||||
|
e.agentVersions = append(e.agentVersions, "")
|
||||||
|
}
|
||||||
|
e.agentVersionOrder[""] = idx
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
if len(e.agentVersions) == 0 {
|
||||||
|
e.agentVersions = append(e.agentVersions, "")
|
||||||
|
e.agentVersionOrder[""] = 0
|
||||||
|
}
|
||||||
|
idx := uint32(len(e.agentVersions))
|
||||||
|
e.agentVersionOrder[v] = idx
|
||||||
|
e.agentVersions = append(e.agentVersions, v)
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
// indexRouterPeers ensures every router peer is in the peer dedup table
|
||||||
|
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
|
||||||
|
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
|
||||||
|
// run before any encoder that resolves peer ids via e.peerOrder.
|
||||||
|
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
|
||||||
|
if len(routers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(routers))
|
||||||
|
for _, p := range routers {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, e.appendPeer(p))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
|
||||||
|
if len(e.components.Groups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
|
||||||
|
for _, g := range e.components.Groups {
|
||||||
|
if !g.HasSeqID() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peerIdxs := make([]uint32, 0, len(g.Peers))
|
||||||
|
for _, peerID := range g.Peers {
|
||||||
|
if idx, ok := e.peerOrder[peerID]; ok {
|
||||||
|
peerIdxs = append(peerIdxs, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, &proto.GroupCompact{
|
||||||
|
Id: g.AccountSeqID,
|
||||||
|
Name: g.Name,
|
||||||
|
PeerIndexes: peerIdxs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
|
||||||
|
// list and a map from policy pointer to the indexes of its emitted rules in
|
||||||
|
// that list — used by encodeResourcePoliciesMap to translate
|
||||||
|
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
|
||||||
|
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
|
||||||
|
if len(policies) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*proto.PolicyCompact, 0, len(policies))
|
||||||
|
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
|
||||||
|
|
||||||
|
for _, pol := range policies {
|
||||||
|
if !pol.HasSeqID() || !pol.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, r := range pol.Rules {
|
||||||
|
if r == nil || !r.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
|
||||||
|
out = append(out, e.encodePolicyRule(pol, r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, idxByPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
|
||||||
|
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
|
||||||
|
return &proto.PolicyCompact{
|
||||||
|
Id: pol.AccountSeqID,
|
||||||
|
Action: networkmap.GetProtoAction(string(r.Action)),
|
||||||
|
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
|
||||||
|
Bidirectional: r.Bidirectional,
|
||||||
|
Ports: portsToUint32(r.Ports),
|
||||||
|
PortRanges: portRangesToProto(r.PortRanges),
|
||||||
|
SourceGroupIds: e.groupSeqIDs(r.Sources),
|
||||||
|
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
|
||||||
|
AuthorizedUser: r.AuthorizedUser,
|
||||||
|
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
|
||||||
|
SourceResource: e.resourceToProto(r.SourceResource),
|
||||||
|
DestinationResource: e.resourceToProto(r.DestinationResource),
|
||||||
|
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
|
||||||
|
// dropping any group that has no seq id assigned.
|
||||||
|
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(src))
|
||||||
|
for _, gid := range src {
|
||||||
|
if seq, ok := e.groupSeq(gid); ok {
|
||||||
|
out = append(out, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// unionPolicies merges c.Policies with every policy referenced by
|
||||||
|
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
|
||||||
|
// policies (relevant to a NetworkResource but not to peer-pair traffic)
|
||||||
|
// only live in ResourcePoliciesMap; without this union step they'd be lost
|
||||||
|
// from the wire and the client's resource-policy lookup would come back
|
||||||
|
// empty.
|
||||||
|
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
|
||||||
|
// Fast path: non-router peers have no resource-only policies, so the
|
||||||
|
// "union" is identical to `policies`. Skip the dedup map allocation.
|
||||||
|
if len(resourcePolicies) == 0 {
|
||||||
|
return policies
|
||||||
|
}
|
||||||
|
seen := make(map[*types.Policy]struct{}, len(policies))
|
||||||
|
out := make([]*types.Policy, 0, len(policies))
|
||||||
|
for _, p := range policies {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[p]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
for _, list := range resourcePolicies {
|
||||||
|
for _, p := range list {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[p]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
|
||||||
|
// group xid → local-user names) to the wire form (map keyed by group
|
||||||
|
// account_seq_id → UserNameList). Groups without a seq id are dropped —
|
||||||
|
// matches how source/destination group references handle the same case.
|
||||||
|
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.UserNameList, len(m))
|
||||||
|
for groupID, names := range m {
|
||||||
|
seq, ok := e.groupSeq(groupID)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
|
||||||
|
g, ok := e.components.Groups[groupID]
|
||||||
|
if !ok || !g.HasSeqID() {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return g.AccountSeqID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// resourceToProto translates types.Resource for the wire. For peer-typed
|
||||||
|
// resources the peer id is converted to a peer index into the envelope's
|
||||||
|
// peers array. For other resource types only the type string is shipped
|
||||||
|
// today (Calculate's resource-typed rule path consults SourceResource only
|
||||||
|
// for "peer" — other types fall through to group-based lookup).
|
||||||
|
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
|
||||||
|
if r.ID == "" && r.Type == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &proto.ResourceCompact{Type: string(r.Type)}
|
||||||
|
if r.Type == types.ResourceTypePeer && r.ID != "" {
|
||||||
|
if idx, ok := e.peerOrder[r.ID]; ok {
|
||||||
|
out.PeerIndexSet = true
|
||||||
|
out.PeerIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// postureCheckSeqs translates a slice of posture-check xids to their
|
||||||
|
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
|
||||||
|
// lookup. Unresolvable xids are silently dropped — matches how group/peer
|
||||||
|
// references handle the same case.
|
||||||
|
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
|
||||||
|
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(xids))
|
||||||
|
for _, xid := range xids {
|
||||||
|
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
|
||||||
|
out = append(out, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkSeq translates a Network xid to its per-account integer id using
|
||||||
|
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
|
||||||
|
// the xid isn't known — callers decide whether to skip the parent record.
|
||||||
|
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
|
||||||
|
if xid == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
seq, ok := e.components.NetworkXIDToSeq[xid]
|
||||||
|
if !ok || seq == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return seq, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
|
||||||
|
if s == nil || len(s.DisabledManagementGroups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &proto.DNSSettingsCompact{
|
||||||
|
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
|
||||||
|
}
|
||||||
|
for _, gid := range s.DisabledManagementGroups {
|
||||||
|
if seq, ok := e.groupSeq(gid); ok {
|
||||||
|
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
|
||||||
|
if len(routes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.RouteRaw, 0, len(routes))
|
||||||
|
for _, r := range routes {
|
||||||
|
if r == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rr := &proto.RouteRaw{
|
||||||
|
Id: r.AccountSeqID,
|
||||||
|
NetId: string(r.NetID),
|
||||||
|
Description: r.Description,
|
||||||
|
KeepRoute: r.KeepRoute,
|
||||||
|
NetworkType: int32(r.NetworkType),
|
||||||
|
Masquerade: r.Masquerade,
|
||||||
|
Metric: int32(r.Metric),
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
SkipAutoApply: r.SkipAutoApply,
|
||||||
|
Domains: r.Domains.ToPunycodeList(),
|
||||||
|
GroupIds: e.groupIDsToSeq(r.Groups),
|
||||||
|
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
|
||||||
|
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||||
|
}
|
||||||
|
if r.Network.IsValid() {
|
||||||
|
rr.NetworkCidr = r.Network.String()
|
||||||
|
}
|
||||||
|
if r.Peer != "" {
|
||||||
|
if idx, ok := e.peerOrder[r.Peer]; ok {
|
||||||
|
rr.PeerIndexSet = true
|
||||||
|
rr.PeerIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, rr)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(groupIDs))
|
||||||
|
for _, gid := range groupIDs {
|
||||||
|
if seq, ok := e.groupSeq(gid); ok {
|
||||||
|
out = append(out, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
|
||||||
|
if len(nsgs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
|
||||||
|
for _, nsg := range nsgs {
|
||||||
|
if nsg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := &proto.NameServerGroupRaw{
|
||||||
|
Id: nsg.AccountSeqID,
|
||||||
|
Name: nsg.Name,
|
||||||
|
Description: nsg.Description,
|
||||||
|
Nameservers: encodeNameServers(nsg.NameServers),
|
||||||
|
GroupIds: e.groupIDsToSeq(nsg.Groups),
|
||||||
|
Primary: nsg.Primary,
|
||||||
|
Domains: nsg.Domains,
|
||||||
|
Enabled: nsg.Enabled,
|
||||||
|
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||||
|
}
|
||||||
|
out = append(out, entry)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
|
||||||
|
if len(servers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.NameServer, 0, len(servers))
|
||||||
|
for _, s := range servers {
|
||||||
|
out = append(out, &proto.NameServer{
|
||||||
|
IP: s.IP.String(),
|
||||||
|
NSType: int64(s.NSType),
|
||||||
|
Port: int64(s.Port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
|
||||||
|
if len(records) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.SimpleRecord, 0, len(records))
|
||||||
|
for _, r := range records {
|
||||||
|
out = append(out, &proto.SimpleRecord{
|
||||||
|
Name: r.Name,
|
||||||
|
Type: int64(r.Type),
|
||||||
|
Class: r.Class,
|
||||||
|
TTL: int64(r.TTL),
|
||||||
|
RData: r.RData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
|
||||||
|
if len(zones) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.CustomZone, 0, len(zones))
|
||||||
|
for _, z := range zones {
|
||||||
|
out = append(out, &proto.CustomZone{
|
||||||
|
Domain: z.Domain,
|
||||||
|
Records: encodeSimpleRecords(z.Records),
|
||||||
|
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||||
|
NonAuthoritative: z.NonAuthoritative,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
|
||||||
|
if len(resources) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
|
||||||
|
for _, r := range resources {
|
||||||
|
if r == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := &proto.NetworkResourceRaw{
|
||||||
|
Id: r.AccountSeqID,
|
||||||
|
Name: r.Name,
|
||||||
|
Description: r.Description,
|
||||||
|
Type: string(r.Type),
|
||||||
|
Address: r.Address,
|
||||||
|
DomainValue: r.Domain,
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
}
|
||||||
|
if seq, ok := e.networkSeq(r.NetworkID); ok {
|
||||||
|
entry.NetworkSeq = seq
|
||||||
|
}
|
||||||
|
if r.Prefix.IsValid() {
|
||||||
|
entry.PrefixCidr = r.Prefix.String()
|
||||||
|
}
|
||||||
|
out = append(out, entry)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
|
||||||
|
if len(routersMap) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
|
||||||
|
for networkXID, routers := range routersMap {
|
||||||
|
if len(routers) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
netSeq, ok := e.networkSeq(networkXID)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
|
||||||
|
for peerID, r := range routers {
|
||||||
|
if r == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := &proto.NetworkRouterEntry{
|
||||||
|
Id: r.AccountSeqID,
|
||||||
|
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||||
|
Masquerade: r.Masquerade,
|
||||||
|
Metric: int32(r.Metric),
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
}
|
||||||
|
if idx, ok := e.peerOrder[peerID]; ok {
|
||||||
|
entry.PeerIndexSet = true
|
||||||
|
entry.PeerIndex = idx
|
||||||
|
}
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
|
||||||
|
if len(rpm) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
|
||||||
|
// (small slice). Network resources without seq id are dropped, matching how
|
||||||
|
// other components-without-seq are silently filtered.
|
||||||
|
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
|
||||||
|
for _, r := range e.components.NetworkResources {
|
||||||
|
if r != nil && r.AccountSeqID != 0 {
|
||||||
|
resourceXIDToSeq[r.ID] = r.AccountSeqID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
|
||||||
|
for resourceXID, policies := range rpm {
|
||||||
|
seq, ok := resourceXIDToSeq[resourceXID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idxs := make([]uint32, 0, len(policies)*2)
|
||||||
|
for _, pol := range policies {
|
||||||
|
idxs = append(idxs, policyToIdxs[pol]...)
|
||||||
|
}
|
||||||
|
if len(idxs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.UserIDList, len(m))
|
||||||
|
for groupID, userIDs := range m {
|
||||||
|
seq, ok := e.groupSeq(groupID)
|
||||||
|
if !ok || len(userIDs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.UserIDList{UserIds: userIDs}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSetToSlice(s map[string]struct{}) []string {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(s))
|
||||||
|
for k := range s {
|
||||||
|
out = append(out, k)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.PeerIndexSet, len(m))
|
||||||
|
for checkXID, failedPeerIDs := range m {
|
||||||
|
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
|
||||||
|
if !ok || seq == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idxs := make([]uint32, 0, len(failedPeerIDs))
|
||||||
|
for peerID := range failedPeerIDs {
|
||||||
|
if idx, ok := e.peerOrder[peerID]; ok {
|
||||||
|
idxs = append(idxs, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(idxs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// toAccountSettingsCompact always returns a non-nil message — the client
|
||||||
|
// dereferences it unconditionally during Calculate(), so a nil here would
|
||||||
|
// crash the receiver. A missing types.AccountSettingsInfo on the server
|
||||||
|
// (which shouldn't happen in production but the encoder is exported)
|
||||||
|
// degrades to login_expiration_enabled = false, which makes
|
||||||
|
// LoginExpired() return false for every peer.
|
||||||
|
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
|
||||||
|
if s == nil {
|
||||||
|
return &proto.AccountSettingsCompact{}
|
||||||
|
}
|
||||||
|
return &proto.AccountSettingsCompact{
|
||||||
|
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||||
|
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
|
||||||
|
if n == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &proto.AccountNetwork{
|
||||||
|
Identifier: n.Identifier,
|
||||||
|
NetCidr: n.Net.String(),
|
||||||
|
Dns: n.Dns,
|
||||||
|
Serial: n.CurrentSerial(),
|
||||||
|
}
|
||||||
|
if len(n.NetV6.IP) > 0 {
|
||||||
|
out.NetV6Cidr = n.NetV6.String()
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
|
||||||
|
pc := &proto.PeerCompact{
|
||||||
|
WgPubKey: decodeWgKey(p.Key),
|
||||||
|
SshPubKey: []byte(p.SSHKey),
|
||||||
|
DnsLabel: p.DNSLabel,
|
||||||
|
AgentVersionIdx: agentVersionIdx,
|
||||||
|
AddedWithSsoLogin: p.UserID != "",
|
||||||
|
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||||
|
SshEnabled: p.SSHEnabled,
|
||||||
|
SupportsIpv6: p.SupportsIPv6(),
|
||||||
|
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
|
||||||
|
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
|
||||||
|
}
|
||||||
|
if p.LastLogin != nil {
|
||||||
|
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case !p.IP.IsValid():
|
||||||
|
// leave Ip nil
|
||||||
|
case p.IP.Is4() || p.IP.Is4In6():
|
||||||
|
ip := p.IP.Unmap().As4()
|
||||||
|
pc.Ip = ip[:]
|
||||||
|
default:
|
||||||
|
ip := p.IP.As16()
|
||||||
|
pc.Ip = ip[:]
|
||||||
|
}
|
||||||
|
if p.IPv6.IsValid() {
|
||||||
|
ip := p.IPv6.As16()
|
||||||
|
pc.Ipv6 = ip[:]
|
||||||
|
}
|
||||||
|
return pc
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
|
||||||
|
// key, or nil for an empty / malformed key.
|
||||||
|
func decodeWgKey(s string) []byte {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]byte, wgKeyRawLen)
|
||||||
|
n, err := base64.StdEncoding.Decode(out, []byte(s))
|
||||||
|
if err != nil || n != wgKeyRawLen {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func portsToUint32(ports []string) []uint32 {
|
||||||
|
if len(ports) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(ports))
|
||||||
|
for _, p := range ports {
|
||||||
|
v, err := strconv.ParseUint(p, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, uint32(v))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
|
||||||
|
if len(ranges) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.PortInfo_Range, 0, len(ranges))
|
||||||
|
for _, r := range ranges {
|
||||||
|
out = append(out, &proto.PortInfo_Range{
|
||||||
|
Start: uint32(r.Start),
|
||||||
|
End: uint32(r.End),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
@@ -0,0 +1,879 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"cmp"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
goproto "google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||||
|
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||||
|
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||||
|
|
||||||
|
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
|
||||||
|
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
|
||||||
|
// to reference the new peer indexes. Groups, policies, and router indexes are
|
||||||
|
// also sorted. After canonicalize, two envelopes built from the same logical
|
||||||
|
// input compare byte-equal via proto.Equal.
|
||||||
|
//
|
||||||
|
// This lives on the test side — the encoder itself emits in map-iteration
|
||||||
|
// order. Test-side normalization is the contract for "two encodes are
|
||||||
|
// equivalent".
|
||||||
|
func canonicalize(full *proto.NetworkMapComponentsFull) {
|
||||||
|
if full == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Canonicalize agent_versions first: sort the slice and rewrite each
|
||||||
|
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
|
||||||
|
// index 0 by convention.
|
||||||
|
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
|
||||||
|
if len(full.AgentVersions) > 0 {
|
||||||
|
// Pair version → original index, sort, rebuild.
|
||||||
|
type avEntry struct {
|
||||||
|
version string
|
||||||
|
oldIdx uint32
|
||||||
|
}
|
||||||
|
entries := make([]avEntry, len(full.AgentVersions))
|
||||||
|
for i, v := range full.AgentVersions {
|
||||||
|
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
|
||||||
|
}
|
||||||
|
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
|
||||||
|
// keeps the canonicalize output stable when two entries compare
|
||||||
|
// equal (the encoder dedups, but defending against future inputs).
|
||||||
|
slices.SortFunc(entries, func(a, b avEntry) int {
|
||||||
|
if a.version == "" && b.version != "" {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if b.version == "" && a.version != "" {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if c := cmp.Compare(a.version, b.version); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return cmp.Compare(a.oldIdx, b.oldIdx)
|
||||||
|
})
|
||||||
|
newVersions := make([]string, len(entries))
|
||||||
|
for newIdx, e := range entries {
|
||||||
|
avRemap[e.oldIdx] = uint32(newIdx)
|
||||||
|
newVersions[newIdx] = e.version
|
||||||
|
}
|
||||||
|
full.AgentVersions = newVersions
|
||||||
|
}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
|
||||||
|
p.AgentVersionIdx = newIdx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerEntry struct {
|
||||||
|
peer *proto.PeerCompact
|
||||||
|
oldIdx uint32
|
||||||
|
}
|
||||||
|
entries := make([]peerEntry, len(full.Peers))
|
||||||
|
for i, p := range full.Peers {
|
||||||
|
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
|
||||||
|
}
|
||||||
|
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
|
||||||
|
// nil from malformed keys, or both empty for placeholders).
|
||||||
|
slices.SortFunc(entries, func(a, b peerEntry) int {
|
||||||
|
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
|
||||||
|
})
|
||||||
|
|
||||||
|
remap := make(map[uint32]uint32, len(entries))
|
||||||
|
newPeers := make([]*proto.PeerCompact, len(entries))
|
||||||
|
for newIdx, e := range entries {
|
||||||
|
remap[e.oldIdx] = uint32(newIdx)
|
||||||
|
newPeers[newIdx] = e.peer
|
||||||
|
}
|
||||||
|
full.Peers = newPeers
|
||||||
|
|
||||||
|
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
|
||||||
|
for _, g := range full.Groups {
|
||||||
|
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
|
||||||
|
}
|
||||||
|
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
|
||||||
|
|
||||||
|
for _, r := range full.Routes {
|
||||||
|
if r.PeerIndexSet {
|
||||||
|
if newIdx, ok := remap[r.PeerIndex]; ok {
|
||||||
|
r.PeerIndex = newIdx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(r.GroupIds)
|
||||||
|
slices.Sort(r.AccessControlGroupIds)
|
||||||
|
slices.Sort(r.PeerGroupIds)
|
||||||
|
}
|
||||||
|
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
|
||||||
|
|
||||||
|
for _, list := range full.RoutersMap {
|
||||||
|
for _, entry := range list.Entries {
|
||||||
|
if entry.PeerIndexSet {
|
||||||
|
if newIdx, ok := remap[entry.PeerIndex]; ok {
|
||||||
|
entry.PeerIndex = newIdx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(entry.PeerGroupIds)
|
||||||
|
}
|
||||||
|
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, set := range full.PostureFailedPeers {
|
||||||
|
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range full.Policies {
|
||||||
|
slices.Sort(p.SourceGroupIds)
|
||||||
|
slices.Sort(p.DestinationGroupIds)
|
||||||
|
}
|
||||||
|
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
|
||||||
|
// multiple PolicyCompact entries sharing the same Id (one per rule, when
|
||||||
|
// a Policy has multiple rules) still get a deterministic order. After
|
||||||
|
// sorting we remap indexes in ResourcePoliciesMap.
|
||||||
|
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
|
||||||
|
for i, p := range full.Policies {
|
||||||
|
policyOldOrder[p] = uint32(i)
|
||||||
|
}
|
||||||
|
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
|
||||||
|
if c := cmp.Compare(a.Id, b.Id); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
|
||||||
|
})
|
||||||
|
policyRemap := make(map[uint32]uint32, len(full.Policies))
|
||||||
|
for newIdx, p := range full.Policies {
|
||||||
|
policyRemap[policyOldOrder[p]] = uint32(newIdx)
|
||||||
|
}
|
||||||
|
for _, idxs := range full.ResourcePoliciesMap {
|
||||||
|
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
|
||||||
|
}
|
||||||
|
for _, list := range full.GroupIdToUserIds {
|
||||||
|
slices.Sort(list.UserIds)
|
||||||
|
}
|
||||||
|
slices.Sort(full.AllowedUserIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
|
||||||
|
out := make([]uint32, 0, len(idxs))
|
||||||
|
for _, i := range idxs {
|
||||||
|
if newIdx, ok := remap[i]; ok {
|
||||||
|
out = append(out, newIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
|
||||||
|
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
|
||||||
|
// the encoder is intentionally non-deterministic.
|
||||||
|
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
|
||||||
|
canonicalize(a.GetFull())
|
||||||
|
canonicalize(b.GetFull())
|
||||||
|
return goproto.Equal(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestComponents() *types.NetworkMapComponents {
|
||||||
|
peerA := &nbpeer.Peer{
|
||||||
|
ID: "peer-a",
|
||||||
|
Key: testWgKeyA,
|
||||||
|
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||||
|
DNSLabel: "peera",
|
||||||
|
SSHKey: "ssh-a",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
peerB := &nbpeer.Peer{
|
||||||
|
ID: "peer-b",
|
||||||
|
Key: testWgKeyB,
|
||||||
|
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||||
|
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
|
||||||
|
DNSLabel: "peerb",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
|
||||||
|
}
|
||||||
|
peerC := &nbpeer.Peer{
|
||||||
|
ID: "peer-c",
|
||||||
|
Key: testWgKeyC,
|
||||||
|
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||||
|
DNSLabel: "peerc",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.NetworkMapComponents{
|
||||||
|
PeerID: "peer-a",
|
||||||
|
Network: &types.Network{
|
||||||
|
Identifier: "net-test",
|
||||||
|
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||||
|
Serial: 7,
|
||||||
|
},
|
||||||
|
AccountSettings: &types.AccountSettingsInfo{
|
||||||
|
PeerLoginExpirationEnabled: true,
|
||||||
|
PeerLoginExpiration: 2 * time.Hour,
|
||||||
|
},
|
||||||
|
Peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-a": peerA,
|
||||||
|
"peer-b": peerB,
|
||||||
|
"peer-c": peerC,
|
||||||
|
},
|
||||||
|
Groups: map[string]*types.Group{
|
||||||
|
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
|
||||||
|
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
|
||||||
|
},
|
||||||
|
Policies: []*types.Policy{
|
||||||
|
{
|
||||||
|
ID: "pol-1",
|
||||||
|
AccountSeqID: 10,
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||||
|
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
|
||||||
|
Ports: []string{"22", "80"},
|
||||||
|
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: c,
|
||||||
|
DNSDomain: "netbird.cloud",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NotNil(t, env)
|
||||||
|
full := env.GetFull()
|
||||||
|
require.NotNil(t, full, "envelope must contain Full payload")
|
||||||
|
|
||||||
|
assert.EqualValues(t, 7, full.Serial)
|
||||||
|
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||||
|
|
||||||
|
require.NotNil(t, full.Network)
|
||||||
|
assert.Equal(t, "net-test", full.Network.Identifier)
|
||||||
|
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
|
||||||
|
|
||||||
|
require.NotNil(t, full.AccountSettings)
|
||||||
|
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||||
|
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
|
||||||
|
|
||||||
|
require.Len(t, full.Peers, 3)
|
||||||
|
byLabel := map[string]*proto.PeerCompact{}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
|
||||||
|
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
|
||||||
|
byLabel[p.DnsLabel] = p
|
||||||
|
}
|
||||||
|
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
|
||||||
|
// Hammer it 100 times — Go map iteration is randomized per call, so each
|
||||||
|
// run produces different wire bytes, but the canonicalized form must
|
||||||
|
// match.
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
require.True(t, envelopesEquivalent(expected, got),
|
||||||
|
"encode #%d must be semantically equivalent to first encode", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
|
||||||
|
const goroutines = 50
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines)
|
||||||
|
results := make([]*proto.NetworkMapEnvelope, goroutines)
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i, got := range results {
|
||||||
|
require.NotNil(t, got, "goroutine %d returned nil", i)
|
||||||
|
require.True(t, envelopesEquivalent(expected, got),
|
||||||
|
"goroutine %d produced inequivalent envelope", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Groups, 2)
|
||||||
|
|
||||||
|
groupByID := map[uint32]*proto.GroupCompact{}
|
||||||
|
for _, g := range full.Groups {
|
||||||
|
groupByID[g.Id] = g
|
||||||
|
}
|
||||||
|
require.Contains(t, groupByID, uint32(1))
|
||||||
|
require.Contains(t, groupByID, uint32(2))
|
||||||
|
assert.Equal(t, "Src", groupByID[1].Name)
|
||||||
|
assert.Equal(t, "Dst", groupByID[2].Name)
|
||||||
|
assert.Len(t, groupByID[1].PeerIndexes, 1)
|
||||||
|
assert.Len(t, groupByID[2].PeerIndexes, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Policies, 1)
|
||||||
|
pc := full.Policies[0]
|
||||||
|
assert.EqualValues(t, 10, pc.Id)
|
||||||
|
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
|
||||||
|
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
|
||||||
|
assert.True(t, pc.Bidirectional)
|
||||||
|
assert.Equal(t, []uint32{22, 80}, pc.Ports)
|
||||||
|
require.Len(t, pc.PortRanges, 1)
|
||||||
|
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
|
||||||
|
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
|
||||||
|
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
|
||||||
|
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.RouterPeerIndexes, 1)
|
||||||
|
idx := full.RouterPeerIndexes[0]
|
||||||
|
require.Less(t, int(idx), len(full.Peers))
|
||||||
|
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
|
||||||
|
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
|
||||||
|
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
|
||||||
|
"two distinct versions, order depends on map iteration")
|
||||||
|
|
||||||
|
idxByLabel := map[string]uint32{}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
|
||||||
|
}
|
||||||
|
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
|
||||||
|
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Policies[0].Enabled = false
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
assert.Empty(t, full.Policies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Groups["group-src"].AccountSeqID = 0
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
|
||||||
|
assert.EqualValues(t, 2, full.Groups[0].Id)
|
||||||
|
|
||||||
|
require.Len(t, full.Policies, 1)
|
||||||
|
pc := full.Policies[0]
|
||||||
|
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
|
||||||
|
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
|
||||||
|
// Both peers have nil WgPubKey after decode; canonicalize must still
|
||||||
|
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
|
||||||
|
// canonicalize identically.
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Peers["peer-a"].Key = "garbage-a-!!!"
|
||||||
|
c.Peers["peer-b"].Key = "garbage-b-!!!"
|
||||||
|
|
||||||
|
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
require.True(t, envelopesEquivalent(expected, got),
|
||||||
|
"encode #%d with two same-key peers must canonicalize equivalently", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Peers["peer-a"].Key = "not-base64-!!!"
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Peers, 3)
|
||||||
|
|
||||||
|
var byLabel = map[string]*proto.PeerCompact{}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
byLabel[p.DnsLabel] = p
|
||||||
|
}
|
||||||
|
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
|
||||||
|
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
v6Only := &nbpeer.Peer{
|
||||||
|
ID: "peer-v6",
|
||||||
|
Key: testWgKeyA,
|
||||||
|
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
|
||||||
|
DNSLabel: "peerv6",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
c.Peers["peer-v6"] = v6Only
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
var found *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peerv6" {
|
||||||
|
found = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, found, "ipv6-only peer must be present")
|
||||||
|
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
|
||||||
|
assert.Len(t, found.Ipv6, 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Peers["peer-noip"] = &nbpeer.Peer{
|
||||||
|
ID: "peer-noip",
|
||||||
|
Key: testWgKeyA,
|
||||||
|
DNSLabel: "peernoip",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
var found *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peernoip" {
|
||||||
|
found = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, found)
|
||||||
|
assert.Empty(t, found.Ip)
|
||||||
|
assert.Empty(t, found.Ipv6)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
|
||||||
|
c := &types.NetworkMapComponents{
|
||||||
|
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||||
|
}
|
||||||
|
|
||||||
|
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
|
||||||
|
full := env.GetFull()
|
||||||
|
require.NotNil(t, full)
|
||||||
|
assert.Empty(t, full.Peers)
|
||||||
|
assert.Empty(t, full.Groups)
|
||||||
|
assert.Empty(t, full.Policies)
|
||||||
|
assert.Empty(t, full.RouterPeerIndexes)
|
||||||
|
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||||
|
c.Peers["peer-a"].UserID = "user-1"
|
||||||
|
c.Peers["peer-a"].LoginExpirationEnabled = true
|
||||||
|
c.Peers["peer-a"].LastLogin = &now
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
var pa *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peera" {
|
||||||
|
pa = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, pa)
|
||||||
|
assert.True(t, pa.AddedWithSsoLogin)
|
||||||
|
assert.True(t, pa.LoginExpirationEnabled)
|
||||||
|
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
|
||||||
|
|
||||||
|
// peer-b has no UserID and no LastLogin → all fields zero-value.
|
||||||
|
var pb *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peerb" {
|
||||||
|
pb = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, pb)
|
||||||
|
assert.False(t, pb.AddedWithSsoLogin)
|
||||||
|
assert.False(t, pb.LoginExpirationEnabled)
|
||||||
|
assert.Zero(t, pb.LastLoginUnixNano)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Routes = []*nbroute.Route{
|
||||||
|
{
|
||||||
|
ID: "route-peer",
|
||||||
|
AccountSeqID: 100,
|
||||||
|
NetID: "net-A",
|
||||||
|
Description: "via peer-c",
|
||||||
|
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||||
|
Peer: "peer-c", // peer ID, not WG key
|
||||||
|
Groups: []string{"group-src"},
|
||||||
|
AccessControlGroups: []string{"group-dst"},
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "route-peergroup",
|
||||||
|
AccountSeqID: 101,
|
||||||
|
NetID: "net-B",
|
||||||
|
Network: netip.MustParsePrefix("10.1.0.0/16"),
|
||||||
|
PeerGroups: []string{"group-src", "group-dst"},
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "route-no-seq",
|
||||||
|
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
|
||||||
|
Network: netip.MustParsePrefix("10.2.0.0/16"),
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Routes, 3)
|
||||||
|
byNetID := map[string]*proto.RouteRaw{}
|
||||||
|
for _, r := range full.Routes {
|
||||||
|
byNetID[r.NetId] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
r1 := byNetID["net-A"]
|
||||||
|
require.NotNil(t, r1)
|
||||||
|
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
|
||||||
|
require.Less(t, int(r1.PeerIndex), len(full.Peers))
|
||||||
|
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
|
||||||
|
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
|
||||||
|
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
|
||||||
|
assert.Empty(t, r1.PeerGroupIds)
|
||||||
|
|
||||||
|
r2 := byNetID["net-B"]
|
||||||
|
require.NotNil(t, r2)
|
||||||
|
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
|
||||||
|
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Routes = []*nbroute.Route{{
|
||||||
|
ID: "route-x",
|
||||||
|
AccountSeqID: 100,
|
||||||
|
Peer: "peer-not-in-components",
|
||||||
|
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Routes, 1)
|
||||||
|
assert.False(t, full.Routes[0].PeerIndexSet,
|
||||||
|
"missing peer reference must not pretend to point at peer index 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
|
||||||
|
// is the I1 case — without unionPolicies the encoder would silently
|
||||||
|
// drop it from the wire.
|
||||||
|
resourceOnlyPolicy := &types.Policy{
|
||||||
|
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||||
|
Protocol: types.PolicyRuleProtocolTCP,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
c.ResourcePoliciesMap = map[string][]*types.Policy{
|
||||||
|
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
|
||||||
|
}
|
||||||
|
// Resource must appear in components.NetworkResources with a seq id —
|
||||||
|
// encoder uses that to translate the xid map key to uint32.
|
||||||
|
c.NetworkResources = []*resourceTypes.NetworkResource{
|
||||||
|
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
|
||||||
|
|
||||||
|
policyByID := map[uint32]*proto.PolicyCompact{}
|
||||||
|
policyIdxByID := map[uint32]uint32{}
|
||||||
|
for i, p := range full.Policies {
|
||||||
|
policyByID[p.Id] = p
|
||||||
|
policyIdxByID[p.Id] = uint32(i)
|
||||||
|
}
|
||||||
|
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
|
||||||
|
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
|
||||||
|
|
||||||
|
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
|
||||||
|
idxs := full.ResourcePoliciesMap[77].Indexes
|
||||||
|
require.Len(t, idxs, 2)
|
||||||
|
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
|
||||||
|
"resource policies map must reference both wire policy indexes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.NameServerGroups = []*nbdns.NameServerGroup{{
|
||||||
|
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
|
||||||
|
NameServers: []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
|
||||||
|
}},
|
||||||
|
Groups: []string{"group-src", "group-not-persisted"},
|
||||||
|
Primary: true, Enabled: true,
|
||||||
|
Domains: []string{"corp.example"},
|
||||||
|
}}
|
||||||
|
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.NameserverGroups, 1)
|
||||||
|
nsg := full.NameserverGroups[0]
|
||||||
|
assert.EqualValues(t, 50, nsg.Id)
|
||||||
|
assert.Equal(t, "Main", nsg.Name)
|
||||||
|
assert.True(t, nsg.Primary)
|
||||||
|
require.Len(t, nsg.Nameservers, 1)
|
||||||
|
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
|
||||||
|
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
|
||||||
|
c.PostureFailedPeers = map[string]map[string]struct{}{
|
||||||
|
"check-1": {
|
||||||
|
"peer-a": {},
|
||||||
|
"peer-b": {},
|
||||||
|
"peer-not-in-account": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Contains(t, full.PostureFailedPeers, uint32(33))
|
||||||
|
idxs := full.PostureFailedPeers[33].PeerIndexes
|
||||||
|
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||||
|
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||||
|
"net-1": {
|
||||||
|
"peer-c": {
|
||||||
|
ID: "router-1", AccountSeqID: 200,
|
||||||
|
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Contains(t, full.RoutersMap, uint32(5))
|
||||||
|
entries := full.RoutersMap[5].Entries
|
||||||
|
require.Len(t, entries, 1)
|
||||||
|
e := entries[0]
|
||||||
|
assert.EqualValues(t, 200, e.Id)
|
||||||
|
assert.True(t, e.PeerIndexSet)
|
||||||
|
require.Less(t, int(e.PeerIndex), len(full.Peers))
|
||||||
|
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
|
||||||
|
assert.True(t, e.Masquerade)
|
||||||
|
assert.EqualValues(t, 10, e.Metric)
|
||||||
|
assert.True(t, e.Enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
|
||||||
|
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
|
||||||
|
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
|
||||||
|
// peer_index reference must still resolve.
|
||||||
|
c := newTestComponents()
|
||||||
|
delete(c.Peers, "peer-c")
|
||||||
|
routerPeer := &nbpeer.Peer{
|
||||||
|
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||||
|
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
|
||||||
|
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||||
|
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||||
|
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Contains(t, full.RoutersMap, uint32(5))
|
||||||
|
require.Len(t, full.RoutersMap[5].Entries, 1)
|
||||||
|
e := full.RoutersMap[5].Entries[0]
|
||||||
|
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.DNSSettings = &types.DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
|
||||||
|
}
|
||||||
|
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.DnsSettings)
|
||||||
|
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
|
||||||
|
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.GroupIDToUserIDs = map[string][]string{
|
||||||
|
"group-src": {"user-1", "user-2"},
|
||||||
|
"group-no-seq": {"user-3"}, // group not persisted → drop
|
||||||
|
"group-missing": {"user-4"}, // group not in components → drop
|
||||||
|
}
|
||||||
|
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
|
||||||
|
require.Contains(t, full.GroupIdToUserIds, uint32(1))
|
||||||
|
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
|
||||||
|
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
|
||||||
|
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
|
||||||
|
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
|
||||||
|
nm := &types.NetworkMap{
|
||||||
|
Peers: []*nbpeer.Peer{{
|
||||||
|
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
|
||||||
|
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}},
|
||||||
|
FirewallRules: []*types.FirewallRule{{
|
||||||
|
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
patch := toProxyPatch(nm, "netbird.cloud", false, false)
|
||||||
|
|
||||||
|
require.NotNil(t, patch)
|
||||||
|
assert.Len(t, patch.Peers, 1)
|
||||||
|
assert.Len(t, patch.FirewallRules, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
|
||||||
|
// pass-through in both encoder branches (normal path + nil-Components
|
||||||
|
// graceful-degrade). Guards against a regression that drops `ProxyPatch:`
|
||||||
|
// from one of the envelope struct literals.
|
||||||
|
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
|
||||||
|
patch := &proto.ProxyPatch{
|
||||||
|
ForwardingRules: []*proto.ForwardingRule{{
|
||||||
|
Protocol: proto.RuleProtocol_TCP,
|
||||||
|
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
|
||||||
|
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
|
||||||
|
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("normal_path", func(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: c,
|
||||||
|
ProxyPatch: patch,
|
||||||
|
}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
|
||||||
|
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: nil,
|
||||||
|
ProxyPatch: patch,
|
||||||
|
}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
|
||||||
|
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
|
||||||
|
// nil Components → minimal envelope, no crash. Matches the legacy
|
||||||
|
// behaviour for missing/unvalidated peers.
|
||||||
|
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: nil,
|
||||||
|
DNSDomain: "netbird.cloud",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NotNil(t, env)
|
||||||
|
full := env.GetFull()
|
||||||
|
require.NotNil(t, full)
|
||||||
|
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
|
||||||
|
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||||
|
assert.Empty(t, full.Peers)
|
||||||
|
assert.Empty(t, full.Policies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
|
||||||
|
c := &types.NetworkMapComponents{
|
||||||
|
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||||
|
// AccountSettings deliberately nil
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
|
||||||
|
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||||
|
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
|
||||||
|
}
|
||||||
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToComponentSyncResponse builds a SyncResponse carrying the compact
|
||||||
|
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
|
||||||
|
// field is intentionally left empty — capable peers ignore it and the
|
||||||
|
// envelope alone is the authoritative wire shape.
|
||||||
|
//
|
||||||
|
// PeerConfig is computed once server-side using the receiving peer's own
|
||||||
|
// account-level network metadata. EnableSSH inside PeerConfig is left at
|
||||||
|
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
|
||||||
|
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
|
||||||
|
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
|
||||||
|
// client even though the server-side PeerConfig reports false.
|
||||||
|
func ToComponentSyncResponse(
|
||||||
|
ctx context.Context,
|
||||||
|
config *nbconfig.Config,
|
||||||
|
httpConfig *nbconfig.HttpServerConfig,
|
||||||
|
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
|
||||||
|
peer *nbpeer.Peer,
|
||||||
|
turnCredentials *Token,
|
||||||
|
relayCredentials *Token,
|
||||||
|
components *types.NetworkMapComponents,
|
||||||
|
proxyPatch *types.NetworkMap,
|
||||||
|
dnsName string,
|
||||||
|
checks []*posture.Checks,
|
||||||
|
settings *types.Settings,
|
||||||
|
extraSettings *types.ExtraSettings,
|
||||||
|
peerGroups []string,
|
||||||
|
dnsFwdPort int64,
|
||||||
|
) *proto.SyncResponse {
|
||||||
|
network := networkOrZero(components)
|
||||||
|
enableSSH := computeSSHEnabledForPeer(components, peer)
|
||||||
|
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
|
||||||
|
|
||||||
|
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||||
|
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||||
|
|
||||||
|
userIDClaim := auth.DefaultUserIDClaim
|
||||||
|
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||||
|
userIDClaim = httpConfig.AuthUserIDClaim
|
||||||
|
}
|
||||||
|
|
||||||
|
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: components,
|
||||||
|
PeerConfig: peerConfig,
|
||||||
|
DNSDomain: dnsName,
|
||||||
|
DNSForwarderPort: dnsFwdPort,
|
||||||
|
UserIDClaim: userIDClaim,
|
||||||
|
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := &proto.SyncResponse{
|
||||||
|
PeerConfig: peerConfig,
|
||||||
|
NetworkMapEnvelope: envelope,
|
||||||
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
|
}
|
||||||
|
|
||||||
|
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||||
|
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkOrZero returns components.Network or a zero Network — toPeerConfig
|
||||||
|
// dereferences network.Net which would panic on nil.
|
||||||
|
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
|
||||||
|
if c == nil || c.Network == nil {
|
||||||
|
return &types.Network{}
|
||||||
|
}
|
||||||
|
return c.Network
|
||||||
|
}
|
||||||
|
|
||||||
|
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
|
||||||
|
// patch the components envelope ships alongside. Returns nil when there are
|
||||||
|
// no fragments to merge — proto3 omits a nil message field, so the receiver
|
||||||
|
// sees no patch and skips the merge step entirely.
|
||||||
|
//
|
||||||
|
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
|
||||||
|
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
|
||||||
|
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
|
||||||
|
// delivers fragments pre-expanded — there's no raw component shape to
|
||||||
|
// derive them from. Components purity isn't violated: proxy data isn't
|
||||||
|
// policy-graph-derived, it's externally injected post-Calculate, so the
|
||||||
|
// client merges it on top of its locally-computed NetworkMap.
|
||||||
|
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
|
||||||
|
if nm == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
|
||||||
|
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
patch := &proto.ProxyPatch{
|
||||||
|
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
|
||||||
|
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
|
||||||
|
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
|
||||||
|
Routes: networkmap.ToProtocolRoutes(nm.Routes),
|
||||||
|
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
|
||||||
|
}
|
||||||
|
if len(nm.ForwardingRules) > 0 {
|
||||||
|
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
|
||||||
|
for _, r := range nm.ForwardingRules {
|
||||||
|
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return patch
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
|
||||||
|
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
|
||||||
|
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
|
||||||
|
// without this helper the field would be incorrectly false for any peer
|
||||||
|
// that's the destination of an SSH-enabling policy without having
|
||||||
|
// peer.SSHEnabled set locally.
|
||||||
|
//
|
||||||
|
// Mirrors the two activation paths Calculate() uses:
|
||||||
|
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
|
||||||
|
// destinations.
|
||||||
|
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
|
||||||
|
// destinations, AND the peer has SSHEnabled set locally — this is the
|
||||||
|
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
|
||||||
|
//
|
||||||
|
// The full SSH AuthorizedUsers map is still produced by the client when it
|
||||||
|
// runs Calculate() over the envelope.
|
||||||
|
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
|
||||||
|
if c == nil || peer == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
|
||||||
|
// exist in c.Peers, otherwise no rule applies to it.
|
||||||
|
if _, ok := c.Peers[peer.ID]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, policy := range c.Policies {
|
||||||
|
if policy == nil || !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if ruleEnablesSSHForPeer(c, rule, peer) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
|
||||||
|
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
|
||||||
|
// peer itself has SSH enabled locally.
|
||||||
|
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
|
||||||
|
if rule == nil || !rule.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !peerInDestinations(c, rule, peer.ID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerInDestinations reports whether peerID is in any of rule.Destinations'
|
||||||
|
// groups (or matches DestinationResource if it's a peer-typed resource —
|
||||||
|
// for non-peer types Calculate falls through to group lookup, so we mirror
|
||||||
|
// that exactly to avoid silent divergence).
|
||||||
|
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
|
||||||
|
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||||
|
return rule.DestinationResource.ID == peerID
|
||||||
|
}
|
||||||
|
for _, groupID := range rule.Destinations {
|
||||||
|
if c.IsPeerInGroup(peerID, groupID) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,184 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
|
||||||
|
// explicit NetbirdSSH protocol, and the legacy implicit case where a
|
||||||
|
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
|
||||||
|
// the destination peer has SSHEnabled=true locally.
|
||||||
|
func TestComputeSSHEnabledForPeer(t *testing.T) {
|
||||||
|
const targetPeerID = "target"
|
||||||
|
const targetGroupID = "g_dst"
|
||||||
|
|
||||||
|
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
|
||||||
|
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
|
||||||
|
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
|
||||||
|
return &types.NetworkMapComponents{
|
||||||
|
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
|
||||||
|
Groups: map[string]*types.Group{targetGroupID: group},
|
||||||
|
Policies: []*types.Policy{{
|
||||||
|
ID: "p",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{rule},
|
||||||
|
}},
|
||||||
|
}, peer
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
peerSSH bool
|
||||||
|
rule types.PolicyRule
|
||||||
|
wantEnabled bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-tcp-22-with-peer-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-tcp-22-without-peer-ssh-disabled",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-tcp-22022-with-peer-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-all-protocol-with-peer-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-port-range-covers-22",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolTCP,
|
||||||
|
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp-80-no-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled-rule-skipped",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer-not-in-destinations",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{"g_other"}, // target not in this group
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer-typed-destination-resource-matches",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-peer-destination-resource-falls-through-to-groups",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
|
||||||
|
Destinations: []string{targetGroupID}, // saved by group fallback
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c, peer := mkComponents(&tc.rule, tc.peerSSH)
|
||||||
|
got := computeSSHEnabledForPeer(c, peer)
|
||||||
|
assert.Equal(t, tc.wantEnabled, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
|
||||||
|
// belt-and-suspenders presence guard mirroring Calculate's
|
||||||
|
// getAllPeersFromGroups invariant.
|
||||||
|
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
|
||||||
|
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
|
||||||
|
c := &types.NetworkMapComponents{
|
||||||
|
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
|
||||||
|
Groups: map[string]*types.Group{
|
||||||
|
"g": {ID: "g", Peers: []string{"missing"}},
|
||||||
|
},
|
||||||
|
Policies: []*types.Policy{{
|
||||||
|
ID: "p", Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{"g"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
assert.False(t, computeSSHEnabledForPeer(c, peer),
|
||||||
|
"missing target peer must short-circuit to false, not consult policies")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
|
||||||
|
// function entry — Calculate doesn't accept nil either, but the helper is
|
||||||
|
// exported indirectly via ToComponentSyncResponse and may receive nil
|
||||||
|
// components on graceful-degrade paths.
|
||||||
|
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
|
||||||
|
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
|
||||||
|
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
|
||||||
|
}
|
||||||
@@ -6,24 +6,22 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
goproto "google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
nbroute "github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
"github.com/netbirdio/netbird/shared/sshauth"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||||
@@ -138,8 +136,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
Serial: networkMap.Network.CurrentSerial(),
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
Routes: toProtocolRoutes(networkMap.Routes),
|
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
|
||||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
},
|
},
|
||||||
Checks: toProtocolChecks(ctx, checks),
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
@@ -152,19 +150,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||||
|
|
||||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||||
response.RemotePeers = remotePeers
|
response.RemotePeers = remotePeers
|
||||||
response.NetworkMap.RemotePeers = remotePeers
|
response.NetworkMap.RemotePeers = remotePeers
|
||||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||||
|
|
||||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||||
|
|
||||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||||
response.NetworkMap.FirewallRules = firewallRules
|
response.NetworkMap.FirewallRules = firewallRules
|
||||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||||
|
|
||||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||||
|
|
||||||
@@ -177,7 +175,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
}
|
}
|
||||||
|
|
||||||
if networkMap.AuthorizedUsers != nil {
|
if networkMap.AuthorizedUsers != nil {
|
||||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||||
userIDClaim := auth.DefaultUserIDClaim
|
userIDClaim := auth.DefaultUserIDClaim
|
||||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||||
userIDClaim = httpConfig.AuthUserIDClaim
|
userIDClaim = httpConfig.AuthUserIDClaim
|
||||||
@@ -185,79 +183,36 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// settings == nil → field stays nil → "no info in this snapshot", client
|
||||||
|
// preserves the deadline it already had. settings non-nil → emit either a
|
||||||
|
// valid deadline or the explicit-zero "disabled" sentinel via
|
||||||
|
// encodeSessionExpiresAt.
|
||||||
|
if settings != nil {
|
||||||
|
response.SessionExpiresAt = encodeSessionExpiresAt(
|
||||||
|
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
// encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
|
||||||
userIDToIndex := make(map[string]uint32)
|
// representation used on LoginResponse, SyncResponse and
|
||||||
var hashedUsers [][]byte
|
// ExtendAuthSessionResponse. See the proto comments on those messages.
|
||||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
//
|
||||||
|
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
|
||||||
for machineUser, users := range authorizedUsers {
|
// "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
|
||||||
indexes := make([]uint32, 0, len(users))
|
// its anchor.
|
||||||
for userID := range users {
|
// - deadline non-zero → returns timestamppb.New(deadline): the new absolute
|
||||||
idx, exists := userIDToIndex[userID]
|
// UTC deadline.
|
||||||
if !exists {
|
//
|
||||||
hash, err := sshauth.HashUserID(userID)
|
// Returning nil ("no info, preserve client's anchor") is the caller's job —
|
||||||
if err != nil {
|
// only meaningful on Sync builds where settings were not resolved.
|
||||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
|
||||||
continue
|
if deadline.IsZero() {
|
||||||
}
|
return ×tamppb.Timestamp{}
|
||||||
idx = uint32(len(hashedUsers))
|
|
||||||
userIDToIndex[userID] = idx
|
|
||||||
hashedUsers = append(hashedUsers, hash[:])
|
|
||||||
}
|
|
||||||
indexes = append(indexes, idx)
|
|
||||||
}
|
|
||||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
|
||||||
}
|
}
|
||||||
|
return timestamppb.New(deadline)
|
||||||
return hashedUsers, machineUsers
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
|
||||||
for _, rPeer := range peers {
|
|
||||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
|
||||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
|
||||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
|
||||||
}
|
|
||||||
dst = append(dst, &proto.RemotePeerConfig{
|
|
||||||
WgPubKey: rPeer.Key,
|
|
||||||
AllowedIps: allowedIPs,
|
|
||||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
|
||||||
Fqdn: rPeer.FQDN(dnsName),
|
|
||||||
AgentVersion: rPeer.Meta.WtVersion,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
|
||||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
|
||||||
protoUpdate := &proto.DNSConfig{
|
|
||||||
ServiceEnable: update.ServiceEnable,
|
|
||||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
|
||||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
|
||||||
ForwarderPort: forwardPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, zone := range update.CustomZones {
|
|
||||||
protoZone := convertToProtoCustomZone(zone)
|
|
||||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, nsGroup := range update.NameServerGroups {
|
|
||||||
cacheKey := nsGroup.ID
|
|
||||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
|
||||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
|
||||||
} else {
|
|
||||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
|
||||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
|
||||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return protoUpdate
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||||
@@ -277,204 +232,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
|
||||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
|
||||||
for _, r := range routes {
|
|
||||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
|
||||||
}
|
|
||||||
return protoRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtocolRoute(route *nbroute.Route) *proto.Route {
|
|
||||||
return &proto.Route{
|
|
||||||
ID: string(route.ID),
|
|
||||||
NetID: string(route.NetID),
|
|
||||||
Network: route.Network.String(),
|
|
||||||
Domains: route.Domains.ToPunycodeList(),
|
|
||||||
NetworkType: int64(route.NetworkType),
|
|
||||||
Peer: route.Peer,
|
|
||||||
Metric: int64(route.Metric),
|
|
||||||
Masquerade: route.Masquerade,
|
|
||||||
KeepRoute: route.KeepRoute,
|
|
||||||
SkipAutoApply: route.SkipAutoApply,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
|
||||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
|
|
||||||
// alongside the deprecated PeerIP for forward compatibility.
|
|
||||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
|
|
||||||
// when includeIPv6 is true.
|
|
||||||
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
|
||||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
|
||||||
for i := range rules {
|
|
||||||
rule := rules[i]
|
|
||||||
|
|
||||||
fwRule := &proto.FirewallRule{
|
|
||||||
PolicyID: []byte(rule.PolicyID),
|
|
||||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
|
||||||
Direction: getProtoDirection(rule.Direction),
|
|
||||||
Action: getProtoAction(rule.Action),
|
|
||||||
Protocol: getProtoProtocol(rule.Protocol),
|
|
||||||
Port: rule.Port,
|
|
||||||
}
|
|
||||||
|
|
||||||
if useSourcePrefixes && rule.PeerIP != "" {
|
|
||||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldUsePortRange(fwRule) {
|
|
||||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
|
||||||
}
|
|
||||||
|
|
||||||
result = append(result, fwRule)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
|
||||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
|
||||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
|
||||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !addr.IsUnspecified() {
|
|
||||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv4Unspecified/0 is always valid, error is impossible.
|
|
||||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
|
||||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
|
||||||
|
|
||||||
if !includeIPv6 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
|
||||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
|
||||||
// IPv6Unspecified/0 is always valid, error is impossible.
|
|
||||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
|
||||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
|
||||||
if shouldUsePortRange(v6Rule) {
|
|
||||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
|
||||||
}
|
|
||||||
return []*proto.FirewallRule{v6Rule}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
|
||||||
func getProtoDirection(direction int) proto.RuleDirection {
|
|
||||||
if direction == types.FirewallRuleDirectionOUT {
|
|
||||||
return proto.RuleDirection_OUT
|
|
||||||
}
|
|
||||||
return proto.RuleDirection_IN
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
|
||||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
|
||||||
for i := range rules {
|
|
||||||
rule := rules[i]
|
|
||||||
result[i] = &proto.RouteFirewallRule{
|
|
||||||
SourceRanges: rule.SourceRanges,
|
|
||||||
Action: getProtoAction(rule.Action),
|
|
||||||
Destination: rule.Destination,
|
|
||||||
Protocol: getProtoProtocol(rule.Protocol),
|
|
||||||
PortInfo: getProtoPortInfo(rule),
|
|
||||||
IsDynamic: rule.IsDynamic,
|
|
||||||
Domains: rule.Domains.ToPunycodeList(),
|
|
||||||
PolicyID: []byte(rule.PolicyID),
|
|
||||||
RouteID: string(rule.RouteID),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoAction converts the action to proto.RuleAction.
|
|
||||||
func getProtoAction(action string) proto.RuleAction {
|
|
||||||
if action == string(types.PolicyTrafficActionDrop) {
|
|
||||||
return proto.RuleAction_DROP
|
|
||||||
}
|
|
||||||
return proto.RuleAction_ACCEPT
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
|
||||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
|
||||||
switch types.PolicyRuleProtocolType(protocol) {
|
|
||||||
case types.PolicyRuleProtocolALL:
|
|
||||||
return proto.RuleProtocol_ALL
|
|
||||||
case types.PolicyRuleProtocolTCP:
|
|
||||||
return proto.RuleProtocol_TCP
|
|
||||||
case types.PolicyRuleProtocolUDP:
|
|
||||||
return proto.RuleProtocol_UDP
|
|
||||||
case types.PolicyRuleProtocolICMP:
|
|
||||||
return proto.RuleProtocol_ICMP
|
|
||||||
default:
|
|
||||||
return proto.RuleProtocol_UNKNOWN
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
|
||||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
|
||||||
var portInfo proto.PortInfo
|
|
||||||
if rule.Port != 0 {
|
|
||||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
|
||||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
|
||||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
|
||||||
Range: &proto.PortInfo_Range{
|
|
||||||
Start: uint32(portRange.Start),
|
|
||||||
End: uint32(portRange.End),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &portInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
|
||||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
|
||||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
|
||||||
protoZone := &proto.CustomZone{
|
|
||||||
Domain: zone.Domain,
|
|
||||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
|
||||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
|
||||||
NonAuthoritative: zone.NonAuthoritative,
|
|
||||||
}
|
|
||||||
for _, record := range zone.Records {
|
|
||||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
|
||||||
Name: record.Name,
|
|
||||||
Type: int64(record.Type),
|
|
||||||
Class: record.Class,
|
|
||||||
TTL: int64(record.TTL),
|
|
||||||
RData: record.RData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return protoZone
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
|
||||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
|
||||||
protoGroup := &proto.NameServerGroup{
|
|
||||||
Primary: nsGroup.Primary,
|
|
||||||
Domains: nsGroup.Domains,
|
|
||||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
|
||||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
|
||||||
}
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
|
||||||
IP: ns.IP.String(),
|
|
||||||
Port: int64(ns.Port),
|
|
||||||
NSType: int64(ns.NSType),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return protoGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
||||||
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
||||||
if config == nil || config.AuthAudience == "" {
|
if config == nil || config.AuthAudience == "" {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||||
@@ -61,13 +63,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First run with config1
|
// First run with config1
|
||||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||||
|
|
||||||
// Second run with config2
|
// Second run with config2
|
||||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||||
|
|
||||||
// Third run with config1 again
|
// Third run with config1 again
|
||||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||||
|
|
||||||
// Verify that result1 and result3 are identical
|
// Verify that result1 and result3 are identical
|
||||||
if !reflect.DeepEqual(result1, result3) {
|
if !reflect.DeepEqual(result1, result3) {
|
||||||
@@ -99,7 +101,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -107,7 +109,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
cache := &cache.DNSConfigCache{}
|
cache := &cache.DNSConfigCache{}
|
||||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -200,3 +202,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestEncodeSessionExpiresAt pins the wire encoding the client's
|
||||||
|
// applySessionDeadline depends on:
|
||||||
|
//
|
||||||
|
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
|
||||||
|
// "expiry disabled or peer is not SSO-tracked" sentinel.
|
||||||
|
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
|
||||||
|
//
|
||||||
|
// The third state (nil pointer = "no info in this snapshot") is the caller's
|
||||||
|
// responsibility on the Sync path when settings could not be resolved; the
|
||||||
|
// helper itself never returns nil.
|
||||||
|
func TestEncodeSessionExpiresAt(t *testing.T) {
|
||||||
|
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
|
||||||
|
got := encodeSessionExpiresAt(time.Time{})
|
||||||
|
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
|
||||||
|
assert.Equal(t, int64(0), got.GetSeconds())
|
||||||
|
assert.Equal(t, int32(0), got.GetNanos())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-zero deadline round-trips", func(t *testing.T) {
|
||||||
|
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||||
|
got := encodeSessionExpiresAt(deadline)
|
||||||
|
assert.NotNil(t, got)
|
||||||
|
assert.True(t, got.AsTime().Equal(deadline))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -136,9 +137,12 @@ type proxyConnection struct {
|
|||||||
tokenID string
|
tokenID string
|
||||||
capabilities *proto.ProxyCapabilities
|
capabilities *proto.ProxyCapabilities
|
||||||
stream proto.ProxyService_GetMappingUpdateServer
|
stream proto.ProxyService_GetMappingUpdateServer
|
||||||
sendChan chan *proto.GetMappingUpdateResponse
|
// syncStream is set when the proxy connected via SyncMappings.
|
||||||
ctx context.Context
|
// When non-nil, the sender goroutine uses this instead of stream.
|
||||||
cancel context.CancelFunc
|
syncStream proto.ProxyService_SyncMappingsServer
|
||||||
|
sendChan chan *proto.GetMappingUpdateResponse
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||||
@@ -206,145 +210,323 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
|||||||
s.proxyController = proxyController
|
s.proxyController = proxyController
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// proxyConnectParams holds the validated parameters extracted from either
|
||||||
|
// a GetMappingUpdateRequest or a SyncMappingsInit message.
|
||||||
|
type proxyConnectParams struct {
|
||||||
|
proxyID string
|
||||||
|
address string
|
||||||
|
capabilities *proto.ProxyCapabilities
|
||||||
|
}
|
||||||
|
|
||||||
// GetMappingUpdate handles the control stream with proxy clients
|
// GetMappingUpdate handles the control stream with proxy clients
|
||||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||||
ctx := stream.Context()
|
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
|
||||||
|
|
||||||
peerInfo := PeerIPFromContext(ctx)
|
|
||||||
log.Infof("New proxy connection from %s", peerInfo)
|
|
||||||
|
|
||||||
proxyID := req.GetProxyId()
|
|
||||||
if proxyID == "" {
|
|
||||||
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyAddress := req.GetAddress()
|
|
||||||
if !isProxyAddressValid(proxyAddress) {
|
|
||||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
var accountID *string
|
|
||||||
token := GetProxyTokenFromContext(ctx)
|
|
||||||
if token != nil && token.AccountID != nil {
|
|
||||||
accountID = token.AccountID
|
|
||||||
|
|
||||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
|
||||||
}
|
|
||||||
if !available {
|
|
||||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var tokenID string
|
|
||||||
if token != nil {
|
|
||||||
tokenID = token.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionID := uuid.NewString()
|
|
||||||
|
|
||||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
|
||||||
oldConn := old.(*proxyConnection)
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"proxy_id": proxyID,
|
|
||||||
"old_session_id": oldConn.sessionID,
|
|
||||||
"new_session_id": sessionID,
|
|
||||||
}).Info("Superseding existing proxy connection")
|
|
||||||
oldConn.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
connCtx, cancel := context.WithCancel(ctx)
|
|
||||||
conn := &proxyConnection{
|
|
||||||
proxyID: proxyID,
|
|
||||||
sessionID: sessionID,
|
|
||||||
address: proxyAddress,
|
|
||||||
accountID: accountID,
|
|
||||||
tokenID: tokenID,
|
|
||||||
capabilities: req.GetCapabilities(),
|
|
||||||
stream: stream,
|
|
||||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
|
||||||
ctx: connCtx,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
|
|
||||||
var caps *proxy.Capabilities
|
|
||||||
if c := req.GetCapabilities(); c != nil {
|
|
||||||
caps = &proxy.Capabilities{
|
|
||||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
|
||||||
RequireSubdomain: c.RequireSubdomain,
|
|
||||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel()
|
return err
|
||||||
if accountID != nil {
|
}
|
||||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
params.capabilities = req.GetCapabilities()
|
||||||
}
|
|
||||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
stream: stream,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
|
||||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
|
||||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
|
||||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
|
||||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
|
||||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
go s.sender(conn, errChan)
|
go s.sender(conn, errChan)
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
|
||||||
"proxy_id": proxyID,
|
}
|
||||||
"session_id": sessionID,
|
|
||||||
"address": proxyAddress,
|
// SyncMappings implements the bidirectional SyncMappings RPC.
|
||||||
"cluster_addr": proxyAddress,
|
// It mirrors GetMappingUpdate but provides application-level back-pressure:
|
||||||
"account_id": accountID,
|
// management waits for an ack from the proxy before sending the next batch.
|
||||||
"total_proxies": len(s.GetConnectedProxies()),
|
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
|
||||||
}).Info("Proxy registered in cluster")
|
init, err := recvSyncInit(stream)
|
||||||
defer func() {
|
if err != nil {
|
||||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
return err
|
||||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
}
|
||||||
cancel()
|
|
||||||
|
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
params.capabilities = init.GetCapabilities()
|
||||||
|
|
||||||
|
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||||
|
syncStream: stream,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
|
||||||
|
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||||
|
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
go s.sender(conn, errChan)
|
||||||
|
go s.drainRecv(stream, errChan)
|
||||||
|
|
||||||
|
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// recvSyncInit receives and validates the first message on a SyncMappings stream.
|
||||||
|
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
|
||||||
|
firstMsg, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
|
||||||
|
}
|
||||||
|
init := firstMsg.GetInit()
|
||||||
|
if init == nil {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
|
||||||
|
}
|
||||||
|
return init, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateProxyConnect validates the proxy ID and address, and checks cluster
|
||||||
|
// address availability for account-scoped tokens.
|
||||||
|
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
|
||||||
|
if proxyID == "" {
|
||||||
|
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||||
|
}
|
||||||
|
if !isProxyAddressValid(address) {
|
||||||
|
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := GetProxyTokenFromContext(ctx)
|
||||||
|
if token != nil && token.AccountID != nil {
|
||||||
|
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||||
|
}
|
||||||
|
if !available {
|
||||||
|
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxyConnectParams{proxyID: proxyID, address: address}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerProxyConnection creates a proxyConnection, registers it with the
|
||||||
|
// proxy manager and cluster, and stores it in connectedProxies. The caller
|
||||||
|
// provides a partially initialised connSeed with stream-specific fields set;
|
||||||
|
// the remaining fields are filled in here.
|
||||||
|
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
|
||||||
|
peerInfo := PeerIPFromContext(ctx)
|
||||||
|
|
||||||
|
var accountID *string
|
||||||
|
var tokenID string
|
||||||
|
if token := GetProxyTokenFromContext(ctx); token != nil {
|
||||||
|
if token.AccountID != nil {
|
||||||
|
accountID = token.AccountID
|
||||||
|
}
|
||||||
|
tokenID = token.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := uuid.NewString()
|
||||||
|
s.supersedePriorConnection(params.proxyID, sessionID)
|
||||||
|
|
||||||
|
connCtx, cancel := context.WithCancel(ctx)
|
||||||
|
connSeed.proxyID = params.proxyID
|
||||||
|
connSeed.sessionID = sessionID
|
||||||
|
connSeed.address = params.address
|
||||||
|
connSeed.accountID = accountID
|
||||||
|
connSeed.tokenID = tokenID
|
||||||
|
connSeed.capabilities = params.capabilities
|
||||||
|
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
|
||||||
|
connSeed.ctx = connCtx
|
||||||
|
connSeed.cancel = cancel
|
||||||
|
|
||||||
|
var caps *proxy.Capabilities
|
||||||
|
if c := params.capabilities; c != nil {
|
||||||
|
caps = &proxy.Capabilities{
|
||||||
|
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||||
|
RequireSubdomain: c.RequireSubdomain,
|
||||||
|
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||||
|
Private: c.Private,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
if accountID != nil {
|
||||||
|
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
|
||||||
|
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.connectedProxies.Store(params.proxyID, connSeed)
|
||||||
|
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return connSeed, proxyRecord, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// supersedePriorConnection cancels any existing connection for the given proxy.
|
||||||
|
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
|
||||||
|
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||||
|
oldConn := old.(*proxyConnection)
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"proxy_id": proxyID,
|
||||||
|
"old_session_id": oldConn.sessionID,
|
||||||
|
"new_session_id": newSessionID,
|
||||||
|
}).Info("Superseding existing proxy connection")
|
||||||
|
oldConn.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupFailedSnapshot removes the connection from the cluster and store
|
||||||
|
// after a snapshot send failure.
|
||||||
|
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
|
||||||
|
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||||
|
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.cancel()
|
||||||
|
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainRecv consumes and discards messages from a bidirectional stream.
|
||||||
|
// The proxy sends an ack for every incremental update; we don't need them
|
||||||
|
// after the snapshot phase. Recv errors are forwarded to errChan.
|
||||||
|
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
|
||||||
|
for {
|
||||||
|
if _, err := stream.Recv(); err != nil {
|
||||||
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
|
||||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
// and wait for termination. When bidi is true, normal stream closure (EOF,
|
||||||
}
|
// canceled) is treated as a clean disconnect rather than an error.
|
||||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
|
||||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
log.WithFields(log.Fields{
|
||||||
}
|
"proxy_id": conn.proxyID,
|
||||||
|
"session_id": conn.sessionID,
|
||||||
|
"address": conn.address,
|
||||||
|
"cluster_addr": conn.address,
|
||||||
|
"account_id": conn.accountID,
|
||||||
|
"total_proxies": len(s.GetConnectedProxies()),
|
||||||
|
}).Info("Proxy registered in cluster")
|
||||||
|
|
||||||
cancel()
|
defer s.disconnectProxy(conn)
|
||||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
go s.heartbeat(conn.ctx, conn, proxyRecord)
|
||||||
}()
|
|
||||||
|
|
||||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
if bidi && isStreamClosed(err) {
|
||||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
log.Infof("Proxy %s stream closed", conn.proxyID)
|
||||||
case <-connCtx.Done():
|
return nil
|
||||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
}
|
||||||
return connCtx.Err()
|
log.Warnf("Failed to send update: %v", err)
|
||||||
|
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
|
||||||
|
case <-conn.ctx.Done():
|
||||||
|
log.Infof("Proxy %s context canceled", conn.proxyID)
|
||||||
|
return conn.ctx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// disconnectProxy removes the connection from cluster and store, unless it
|
||||||
|
// has already been superseded by a newer connection.
|
||||||
|
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
|
||||||
|
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||||
|
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
|
||||||
|
conn.cancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||||
|
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
|
||||||
|
}
|
||||||
|
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||||
|
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.cancel()
|
||||||
|
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
|
||||||
|
// one batch, then waits for the proxy to ack before sending the next.
|
||||||
|
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
|
||||||
|
if !isProxyAddressValid(conn.address) {
|
||||||
|
return fmt.Errorf("proxy address is invalid")
|
||||||
|
}
|
||||||
|
if s.snapshotBatchSize <= 0 {
|
||||||
|
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||||
|
end := i + s.snapshotBatchSize
|
||||||
|
if end > len(mappings) {
|
||||||
|
end = len(mappings)
|
||||||
|
}
|
||||||
|
for _, m := range mappings[i:end] {
|
||||||
|
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||||
|
}
|
||||||
|
m.AuthToken = token
|
||||||
|
}
|
||||||
|
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||||
|
Mapping: mappings[i:end],
|
||||||
|
InitialSyncComplete: end == len(mappings),
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("send snapshot batch: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := waitForAck(stream); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(mappings) == 0 {
|
||||||
|
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||||
|
InitialSyncComplete: true,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("send snapshot completion: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := waitForAck(stream); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
|
||||||
|
msg, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("receive ack: %w", err)
|
||||||
|
}
|
||||||
|
if msg.GetAck() == nil {
|
||||||
|
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||||
// disconnects the proxy if its access token has been revoked.
|
// disconnects the proxy if its access token has been revoked.
|
||||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||||
@@ -381,6 +563,9 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
if !isProxyAddressValid(conn.address) {
|
if !isProxyAddressValid(conn.address) {
|
||||||
return fmt.Errorf("proxy address is invalid")
|
return fmt.Errorf("proxy address is invalid")
|
||||||
}
|
}
|
||||||
|
if s.snapshotBatchSize <= 0 {
|
||||||
|
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -394,6 +579,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
if end > len(mappings) {
|
if end > len(mappings) {
|
||||||
end = len(mappings)
|
end = len(mappings)
|
||||||
}
|
}
|
||||||
|
for _, m := range mappings[i:end] {
|
||||||
|
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||||
|
}
|
||||||
|
m.AuthToken = token
|
||||||
|
}
|
||||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||||
Mapping: mappings[i:end],
|
Mapping: mappings[i:end],
|
||||||
InitialSyncComplete: end == len(mappings),
|
InitialSyncComplete: end == len(mappings),
|
||||||
@@ -425,18 +617,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
|||||||
return nil, fmt.Errorf("get services from store: %w", err)
|
return nil, fmt.Errorf("get services from store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oidcCfg := s.GetOIDCValidationConfig()
|
||||||
var mappings []*proto.ProxyMapping
|
var mappings []*proto.ProxyMapping
|
||||||
for _, service := range services {
|
for _, service := range services {
|
||||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
|
||||||
if !proxyAcceptsMapping(conn, m) {
|
if !proxyAcceptsMapping(conn, m) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -457,12 +645,26 @@ func isProxyAddressValid(addr string) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sender handles sending messages to proxy
|
// isStreamClosed returns true for errors that indicate normal stream
|
||||||
|
// termination: io.EOF, context cancellation, or gRPC Canceled.
|
||||||
|
func isStreamClosed(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return status.Code(err) == codes.Canceled
|
||||||
|
}
|
||||||
|
|
||||||
|
// sender handles sending messages to proxy.
|
||||||
|
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
|
||||||
|
// otherwise the legacy GetMappingUpdateResponse stream is used.
|
||||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case resp := <-conn.sendChan:
|
case resp := <-conn.sendChan:
|
||||||
if err := conn.stream.Send(resp); err != nil {
|
if err := conn.sendResponse(resp); err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -472,6 +674,17 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendResponse sends a mapping update on whichever stream the proxy connected with.
|
||||||
|
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
|
||||||
|
if conn.syncStream != nil {
|
||||||
|
return conn.syncStream.Send(&proto.SyncMappingsResponse{
|
||||||
|
Mapping: resp.Mapping,
|
||||||
|
InitialSyncComplete: resp.InitialSyncComplete,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return conn.stream.Send(resp)
|
||||||
|
}
|
||||||
|
|
||||||
// SendAccessLog processes access log from proxy
|
// SendAccessLog processes access log from proxy
|
||||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||||
accessLog := req.GetLog()
|
accessLog := req.GetLog()
|
||||||
@@ -538,10 +751,15 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
connUpdate = &proto.GetMappingUpdateResponse{
|
connUpdate = &proto.GetMappingUpdateResponse{
|
||||||
Mapping: filtered,
|
Mapping: filtered,
|
||||||
InitialSyncComplete: update.InitialSyncComplete,
|
InitialSyncComplete: update.InitialSyncComplete,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
|
||||||
|
connUpdate = filterMappingsForProxy(conn, connUpdate)
|
||||||
|
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||||
@@ -670,16 +888,20 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
|
||||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
// Private mappings require SupportsPrivateService; custom-port L4 mappings
|
||||||
// mappings with a custom listen port, since they don't understand the
|
// require SupportsCustomPorts. Remove operations always pass so proxies can
|
||||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
// clean up.
|
||||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
|
||||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
|
||||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
if mapping.GetPrivate() {
|
||||||
|
caps := conn.capabilities
|
||||||
|
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -688,6 +910,29 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
|||||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterMappingsForProxy drops mappings the proxy cannot safely receive
|
||||||
|
// (e.g. private mappings to a proxy without SupportsPrivateService).
|
||||||
|
// Returns the input unchanged when no filtering is needed.
|
||||||
|
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||||
|
if update == nil || len(update.Mapping) == 0 {
|
||||||
|
return update
|
||||||
|
}
|
||||||
|
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||||
|
for _, m := range update.Mapping {
|
||||||
|
if !proxyAcceptsMapping(conn, m) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, m)
|
||||||
|
}
|
||||||
|
if len(kept) == len(update.Mapping) {
|
||||||
|
return update
|
||||||
|
}
|
||||||
|
return &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: kept,
|
||||||
|
InitialSyncComplete: update.InitialSyncComplete,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||||
// create/update operations. For delete operations the original mapping is
|
// create/update operations. For delete operations the original mapping is
|
||||||
// used unchanged because proxies do not need to authenticate for removal.
|
// used unchanged because proxies do not need to authenticate for removal.
|
||||||
@@ -749,7 +994,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
|||||||
|
|
||||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||||
|
|
||||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
|
||||||
|
// secrets and have no user-level group context, so groups stay nil. Email
|
||||||
|
// is also empty — these schemes don't resolve a user record at sign time.
|
||||||
|
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -838,7 +1086,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
|
||||||
if !authenticated || service.SessionPrivateKey == "" {
|
if !authenticated || service.SessionPrivateKey == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@@ -846,8 +1094,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
|||||||
token, err := sessionkey.SignToken(
|
token, err := sessionkey.SignToken(
|
||||||
service.SessionPrivateKey,
|
service.SessionPrivateKey,
|
||||||
userId,
|
userId,
|
||||||
|
userEmail,
|
||||||
service.Domain,
|
service.Domain,
|
||||||
method,
|
method,
|
||||||
|
groupIDs,
|
||||||
|
groupNames,
|
||||||
proxyauth.DefaultSessionExpiry,
|
proxyauth.DefaultSessionExpiry,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -858,6 +1109,26 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
|
||||||
|
// into parallel id and name slices. ids[i] and names[i] always pair to
|
||||||
|
// the same group. nil entries (orphan ids the manager couldn't resolve)
|
||||||
|
// are skipped so the consumer can rely on positional pairing.
|
||||||
|
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
|
||||||
|
if len(groups) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
ids = make([]string, 0, len(groups))
|
||||||
|
names = make([]string, 0, len(groups))
|
||||||
|
for _, g := range groups {
|
||||||
|
if g == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ids = append(ids, g.ID)
|
||||||
|
names = append(names, g.Name)
|
||||||
|
}
|
||||||
|
return ids, names
|
||||||
|
}
|
||||||
|
|
||||||
// SendStatusUpdate handles status updates from proxy clients.
|
// SendStatusUpdate handles status updates from proxy clients.
|
||||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||||
@@ -1122,7 +1393,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
|||||||
return verifier, redirectURL, nil
|
return verifier, redirectURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
// GenerateSessionToken creates a signed session JWT for the given domain and
|
||||||
|
// user. The user's group memberships are embedded in the token so policy-aware
|
||||||
|
// middlewares on the proxy can authorise without an extra management round-trip.
|
||||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||||
service, err := s.getServiceByDomain(ctx, domain)
|
service, err := s.getServiceByDomain(ctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1133,11 +1406,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
|
|||||||
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
email string
|
||||||
|
groupIDs []string
|
||||||
|
groupNames []string
|
||||||
|
)
|
||||||
|
if s.usersManager != nil {
|
||||||
|
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||||
|
if uerr != nil {
|
||||||
|
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
|
||||||
|
} else if user != nil {
|
||||||
|
email = user.Email
|
||||||
|
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return sessionkey.SignToken(
|
return sessionkey.SignToken(
|
||||||
service.SessionPrivateKey,
|
service.SessionPrivateKey,
|
||||||
userID,
|
userID,
|
||||||
|
email,
|
||||||
domain,
|
domain,
|
||||||
method,
|
method,
|
||||||
|
groupIDs,
|
||||||
|
groupNames,
|
||||||
proxyauth.DefaultSessionExpiry,
|
proxyauth.DefaultSessionExpiry,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -1241,7 +1532,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"domain": domain,
|
"domain": domain,
|
||||||
@@ -1254,7 +1545,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.usersManager.GetUser(ctx, userID)
|
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"domain": domain,
|
"domain": domain,
|
||||||
@@ -1288,12 +1579,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
"user_id": userID,
|
"user_id": userID,
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
}).Debug("ValidateSession: access denied")
|
}).Debug("ValidateSession: access denied")
|
||||||
|
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||||
//nolint:nilerr
|
//nolint:nilerr
|
||||||
return &proto.ValidateSessionResponse{
|
return &proto.ValidateSessionResponse{
|
||||||
Valid: false,
|
Valid: false,
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
UserEmail: user.Email,
|
UserEmail: user.Email,
|
||||||
DeniedReason: "not_in_group",
|
DeniedReason: "not_in_group",
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1303,10 +1597,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
"email": user.Email,
|
"email": user.Email,
|
||||||
}).Debug("ValidateSession: access granted")
|
}).Debug("ValidateSession: access granted")
|
||||||
|
|
||||||
|
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||||
return &proto.ValidateSessionResponse{
|
return &proto.ValidateSessionResponse{
|
||||||
Valid: true,
|
Valid: true,
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
UserEmail: user.Email,
|
UserEmail: user.Email,
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1339,3 +1636,154 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ptr[T any](v T) *T { return &v }
|
func ptr[T any](v T) *T { return &v }
|
||||||
|
|
||||||
|
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||||
|
// checks the peer's group membership against the service's access groups.
|
||||||
|
// Peers without a user (machine agents, automation workloads) are first-class
|
||||||
|
// callers; authorisation runs off peer-group memberships rather than the
|
||||||
|
// optional owning user's auto-groups. On success a session JWT is minted so
|
||||||
|
// the proxy can install a cookie and skip subsequent management round-trips.
|
||||||
|
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||||
|
domain := req.GetDomain()
|
||||||
|
tunnelIPStr := req.GetTunnelIp()
|
||||||
|
|
||||||
|
if domain == "" || tunnelIPStr == "" {
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "missing domain or tunnel_ip",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelIP := net.ParseIP(tunnelIPStr)
|
||||||
|
if tunnelIP == nil {
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "invalid_tunnel_ip",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := s.getServiceByDomain(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "service_not_found",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
|
||||||
|
// validate and mint session cookies for their own account's domains.
|
||||||
|
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
|
||||||
|
if err != nil || peer == nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "peer_not_found",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "peer_not_found",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||||
|
|
||||||
|
// Resolve the principal: when the peer is linked to a user, the human
|
||||||
|
// is the principal so multiple peers owned by the same user share a
|
||||||
|
// single identity. Unlinked peers (machine agents) are their own
|
||||||
|
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||||
|
// tag spend with — user.Email when linked, peer.Name when not.
|
||||||
|
principalID := peer.ID
|
||||||
|
displayIdentity := peer.Name
|
||||||
|
if peer.UserID != "" {
|
||||||
|
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||||
|
principalID = user.Id
|
||||||
|
if user.Email != "" {
|
||||||
|
displayIdentity = user.Email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
UserId: principalID,
|
||||||
|
UserEmail: displayIdentity,
|
||||||
|
DeniedReason: "not_in_group",
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"tunnel_ip": tunnelIPStr,
|
||||||
|
"peer_id": peer.ID,
|
||||||
|
"principal_id": principalID,
|
||||||
|
}).Debug("ValidateTunnelPeer: access granted")
|
||||||
|
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: true,
|
||||||
|
UserId: principalID,
|
||||||
|
UserEmail: displayIdentity,
|
||||||
|
SessionToken: token,
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||||
|
// groups. Private services authorise against AccessGroups (empty list fails
|
||||||
|
// closed — Validate() rejects that at save time but the RPC is the security
|
||||||
|
// boundary and must not trust upstream state). Bearer-auth services authorise
|
||||||
|
// against DistributionGroups when populated. Non-private non-bearer services
|
||||||
|
// are open.
|
||||||
|
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
|
||||||
|
if service.Private {
|
||||||
|
if len(service.AccessGroups) == 0 {
|
||||||
|
return fmt.Errorf("private service has no access groups")
|
||||||
|
}
|
||||||
|
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
|
||||||
|
}
|
||||||
|
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
|
||||||
|
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
|
||||||
|
// else a non-nil error.
|
||||||
|
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
|
||||||
|
if len(allowedGroups) == 0 {
|
||||||
|
return fmt.Errorf("no allowed groups configured")
|
||||||
|
}
|
||||||
|
allowed := make(map[string]struct{}, len(allowedGroups))
|
||||||
|
for _, g := range allowedGroups {
|
||||||
|
allowed[g] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, g := range peerGroupIDs {
|
||||||
|
if _, ok := allowed[g]; ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("peer not in allowed groups")
|
||||||
|
}
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s
|
|||||||
return nil, errors.New("service not found for domain: " + domain)
|
return nil, errors.New("service not found for domain: " + domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||||
|
user, err := m.GetUser(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return user, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateUserGroupAccess(t *testing.T) {
|
func TestValidateUserGroupAccess(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCheckPeerGroupAccess(t *testing.T) {
|
||||||
|
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
|
||||||
|
svc := &service.Service{Private: true, AccessGroups: nil}
|
||||||
|
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no access groups")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
|
||||||
|
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
|
||||||
|
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
|
||||||
|
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
|
||||||
|
svc := &service.Service{
|
||||||
|
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
|
||||||
|
}
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
|
||||||
|
svc := &service.Service{
|
||||||
|
Auth: service.AuthConfig{
|
||||||
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"grp-allowed"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
|
||||||
|
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-private non-bearer is open", func(t *testing.T) {
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
|||||||
assert.Empty(t, stream.messages[0].Mapping)
|
assert.Empty(t, stream.messages[0].Mapping)
|
||||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type hookingStream struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
onSend func(*proto.GetMappingUpdateResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||||
|
if s.onSend != nil {
|
||||||
|
s.onSend(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *hookingStream) Context() context.Context { return context.Background() }
|
||||||
|
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
|
||||||
|
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
|
||||||
|
func (s *hookingStream) SetTrailer(metadata.MD) {}
|
||||||
|
func (s *hookingStream) SendMsg(any) error { return nil }
|
||||||
|
func (s *hookingStream) RecvMsg(any) error { return nil }
|
||||||
|
|
||||||
|
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
|
||||||
|
const cluster = "cluster.example.com"
|
||||||
|
const batchSize = 2
|
||||||
|
const totalServices = 6
|
||||||
|
const ttl = 100 * time.Millisecond
|
||||||
|
const sendDelay = 200 * time.Millisecond
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mgr := rpservice.NewMockManager(ctrl)
|
||||||
|
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||||
|
|
||||||
|
s := newSnapshotTestServer(t, batchSize)
|
||||||
|
s.serviceManager = mgr
|
||||||
|
s.tokenTTL = ttl
|
||||||
|
|
||||||
|
var validateErrs []error
|
||||||
|
stream := &hookingStream{
|
||||||
|
onSend: func(resp *proto.GetMappingUpdateResponse) {
|
||||||
|
for _, m := range resp.Mapping {
|
||||||
|
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
|
||||||
|
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(sendDelay)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||||
|
|
||||||
|
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||||
|
require.Empty(t, validateErrs,
|
||||||
|
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
|
||||||
|
}
|
||||||
|
|||||||
@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
|
||||||
if debouncer.ProcessUpdate(update) {
|
if debouncer.ProcessUpdate(update) {
|
||||||
// Send immediately (first update or after quiet period)
|
// Send immediately (first update or after quiet period)
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||||
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
|||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
uncanceledCTX := context.WithoutCancel(ctx)
|
||||||
|
unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
@@ -820,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
|
||||||
|
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
|
||||||
|
// stays up; no network map sync is performed. The new deadline is returned
|
||||||
|
// in ExtendAuthSessionResponse.SessionExpiresAt.
|
||||||
|
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
extendReq := &proto.ExtendAuthSessionRequest{}
|
||||||
|
peerKey, err := s.parseRequest(ctx, req, extendReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||||
|
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwt := extendReq.GetJwtToken()
|
||||||
|
if jwt == "" {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID string
|
||||||
|
const attempts = 3
|
||||||
|
for i := 0; i < attempts; i++ {
|
||||||
|
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i == attempts-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
|
||||||
|
select {
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if userID == "" {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
|
||||||
|
return nil, mapError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success path normally returns a non-zero deadline. A defensive zero
|
||||||
|
// would still encode as the explicit "disabled" sentinel rather than nil,
|
||||||
|
// so the client clears any stale anchor instead of preserving it.
|
||||||
|
resp := &proto.ExtendAuthSessionResponse{
|
||||||
|
SessionExpiresAt: encodeSessionExpiresAt(deadline),
|
||||||
|
}
|
||||||
|
|
||||||
|
wgKey, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed processing request")
|
||||||
|
}
|
||||||
|
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed encrypting response")
|
||||||
|
}
|
||||||
|
return &proto.EncryptedMessage{
|
||||||
|
WgPubKey: wgKey.PublicKey().String(),
|
||||||
|
Body: encrypted,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||||
var relayToken *Token
|
var relayToken *Token
|
||||||
var err error
|
var err error
|
||||||
@@ -843,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
|||||||
Checks: toProtocolChecks(ctx, postureChecks),
|
Checks: toProtocolChecks(ctx, postureChecks),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// settings is always non-nil here, so we never emit nil — encoder returns
|
||||||
|
// either a valid deadline or the explicit-zero "disabled" sentinel.
|
||||||
|
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
|
||||||
|
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||||
|
)
|
||||||
|
|
||||||
return loginResp, nil
|
return loginResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -931,7 +1012,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
|||||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
dnsName := s.networkMapController.GetDNSDomain(settings)
|
||||||
|
|
||||||
|
var plainResp *proto.SyncResponse
|
||||||
|
if s.networkMapController.PeerNeedsComponents(peer) {
|
||||||
|
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
|
||||||
|
// computed and recompute the raw components instead. This wastes one
|
||||||
|
// Calculate() call per initial-sync — the component-based wire
|
||||||
|
// format is what the peer actually consumes. The streaming path
|
||||||
|
// (network_map.Controller.UpdateAccountPeers) skips this duplication
|
||||||
|
// because it dispatches by capability before computing.
|
||||||
|
//
|
||||||
|
// TODO: refactor SyncPeer / SyncAndMarkPeer / their mocks + manager
|
||||||
|
// interfaces to return PeerNetworkMapResult so the initial-sync path
|
||||||
|
// stops doing duplicate work. Deferred until the client-side
|
||||||
|
// decoder lands and there's a real deployment of capability=3 peers
|
||||||
|
// worth optimizing for.
|
||||||
|
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
|
||||||
|
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
|
||||||
|
}
|
||||||
|
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||||
|
} else {
|
||||||
|
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||||
|
}
|
||||||
|
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user