mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-27 02:59:54 +00:00
Compare commits
25 Commits
windows-dn
...
chore/adju
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a8380623d | ||
|
|
14af179556 | ||
|
|
1fbb5e6d5d | ||
|
|
6771e35d57 | ||
|
|
e89b1e0596 | ||
|
|
d542c60e21 | ||
|
|
4983b5cf17 | ||
|
|
b3b0feb3b8 | ||
|
|
7aebdd69dd | ||
|
|
0358be2313 | ||
|
|
37052fd5bc | ||
|
|
454ff66518 | ||
|
|
6137a1fcc5 | ||
|
|
4955c345d5 | ||
|
|
9192b4f029 | ||
|
|
c784b02550 | ||
|
|
d250f92c43 | ||
|
|
80966ab1b0 | ||
|
|
af24fd7796 | ||
|
|
13d32d274f | ||
|
|
705f87fc20 | ||
|
|
3f91f49277 | ||
|
|
347c5bf317 | ||
|
|
22e2519d71 | ||
|
|
e916f12cca |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,6 +12,7 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
|
||||
66
.github/workflows/proto-version-check.yml
vendored
66
.github/workflows/proto-version-check.yml
vendored
@@ -20,34 +20,66 @@ jobs:
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
||||
if (missingPatch.length > 0) {
|
||||
core.setFailed(
|
||||
`Cannot inspect patch data for:\n` +
|
||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
||||
);
|
||||
const modifiedPbFiles = files.filter(
|
||||
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||
);
|
||||
if (modifiedPbFiles.length === 0) {
|
||||
console.log('No modified .pb.go files to check');
|
||||
return;
|
||||
}
|
||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const violations = [];
|
||||
|
||||
for (const file of pbFiles) {
|
||||
const changed = file.patch
|
||||
.split('\n')
|
||||
.filter(line => versionPattern.test(line));
|
||||
if (changed.length > 0) {
|
||||
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const baseSha = context.payload.pull_request.base.sha;
|
||||
const headSha = context.payload.pull_request.head.sha;
|
||||
|
||||
async function getVersionHeader(path, ref) {
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
path,
|
||||
ref,
|
||||
});
|
||||
if (!res.data.content) {
|
||||
return { ok: false, reason: 'no inline content (file too large)' };
|
||||
}
|
||||
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
|
||||
const lines = content
|
||||
.split('\n')
|
||||
.slice(0, 20)
|
||||
.filter(line => versionPattern.test(line));
|
||||
return { ok: true, lines };
|
||||
} catch (e) {
|
||||
return { ok: false, reason: e.message };
|
||||
}
|
||||
}
|
||||
|
||||
const violations = [];
|
||||
for (const file of modifiedPbFiles) {
|
||||
const [base, head] = await Promise.all([
|
||||
getVersionHeader(file.filename, baseSha),
|
||||
getVersionHeader(file.filename, headSha),
|
||||
]);
|
||||
if (!base.ok || !head.ok) {
|
||||
core.warning(
|
||||
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||
violations.push({
|
||||
file: file.filename,
|
||||
lines: changed,
|
||||
base: base.lines,
|
||||
head: head.lines,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (violations.length > 0) {
|
||||
const details = violations.map(v =>
|
||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
||||
`${v.file}:\n` +
|
||||
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
|
||||
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
|
||||
).join('\n\n');
|
||||
|
||||
core.setFailed(
|
||||
|
||||
@@ -92,9 +92,6 @@ linters:
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
|
||||
@@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
|
||||
- [Contributing to NetBird](#contributing-to-netbird)
|
||||
- [Contents](#contents)
|
||||
- [Code of conduct](#code-of-conduct)
|
||||
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||
- [Directory structure](#directory-structure)
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
@@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
||||
By participating, you are expected to uphold this code. Please report
|
||||
unacceptable behavior to community@netbird.io.
|
||||
|
||||
## Discuss changes with the NetBird team first
|
||||
|
||||
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||
|
||||
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||
|
||||
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||
|
||||
## Directory structure
|
||||
|
||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||
|
||||
153
README.md
153
README.md
@@ -1,147 +1,134 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
|
||||
</a>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
|
||||
</a>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
</strong>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
<br>
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### Self-Host NetBird (Video)
|
||||
### Self-host NetBird (video)
|
||||
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|---|---|---|---|---|
|
||||
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
|
||||
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
|
||||
| ✓ Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
|
||||
| ✓ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
|
||||
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
|
||||
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
|
||||
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
|
||||
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
|
||||
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
|
||||
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
|
||||
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
|
||||
| | | | | ✓ OpenWRT |
|
||||
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
|
||||
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
|
||||
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
|
||||
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
|
||||
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
|
||||
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||
- Add more machines.
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
|
||||
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
|
||||
- Check the NetBird [admin UI](https://app.netbird.io/).
|
||||
|
||||
### Quickstart with self-hosted NetBird
|
||||
|
||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
|
||||
- A **public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
|
||||
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
### A bit on NetBird internals
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
|
||||
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
|
||||
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
|
||||
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
|
||||
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
|
||||
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
|
||||
|
||||

|
||||
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
### Acknowledgements
|
||||
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
|
||||
|
||||
### Legal
|
||||
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -84,6 +85,12 @@ type Options struct {
|
||||
DisableIPv6 bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// BlockLANAccess blocks the embedded peer from reaching the host's
|
||||
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
|
||||
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
|
||||
// when the embedded client must never act as a stepping stone into
|
||||
// the host's local network (e.g. the proxy's overlay peer).
|
||||
BlockLANAccess bool
|
||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the tunnel interface.
|
||||
@@ -94,6 +101,26 @@ type Options struct {
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||
Performance Performance
|
||||
}
|
||||
|
||||
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||
//
|
||||
// These settings are process-global: any non-nil field also becomes the
|
||||
// default for Clients constructed by later embed.New calls in the same
|
||||
// process. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||
// leaves the pool unbounded. Lower values trade throughput for a
|
||||
// tighter memory ceiling. May also be changed on a running Client via
|
||||
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||
// writes per syscall, which also bounds eager buffer allocation per
|
||||
// worker. Zero uses the platform default. Applied at construction
|
||||
// only; ignored by Client.SetPerformance.
|
||||
MaxBatchSize *uint32
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
DisableIPv6: &opts.DisableIPv6,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
BlockLANAccess: &opts.BlockLANAccess,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
DNSLabels: parsedLabels,
|
||||
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||
}
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
}
|
||||
state, found := c.recorder.PeerStateByIP(ip.String())
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
return state.PubKey, state.FQDN, true
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||
// takes effect, and only when it was nonzero at construction;
|
||||
// MaxBatchSize is construction-only and returns an error if set here.
|
||||
//
|
||||
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||
// running yet.
|
||||
func (c *Client) SetPerformance(t Performance) error {
|
||||
if t.MaxBatchSize != nil {
|
||||
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||
}
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetPerformance(internal.Performance{
|
||||
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||
})
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
|
||||
@@ -52,9 +52,10 @@ func (m *externalChainMonitor) start() {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
m.done = make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
m.done = done
|
||||
|
||||
go m.run(ctx)
|
||||
go m.run(ctx, done)
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) stop() {
|
||||
@@ -72,8 +73,8 @@ func (m *externalChainMonitor) stop() {
|
||||
<-done
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) run(ctx context.Context) {
|
||||
defer close(m.done)
|
||||
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
|
||||
defer close(done)
|
||||
|
||||
bo := &backoff.ExponentialBackOff{
|
||||
InitialInterval: externalMonitorInitInterval,
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
|
||||
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
|
||||
// EnvPorts overrides the comma-separated list of remote ports to block.
|
||||
// Empty disables the firewall.
|
||||
EnvPorts = "NB_DNS_FIREWALL_PORTS"
|
||||
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
|
||||
// and the netbird daemon. Default mode also permits anything on the
|
||||
// netbird tunnel interface, which is safer if NRPT is silently ignored
|
||||
// by Windows but lets apps reach custom DNS servers via the tunnel.
|
||||
EnvStrict = "NB_DNS_FIREWALL_STRICT"
|
||||
)
|
||||
|
||||
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
|
||||
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
|
||||
var defaultBlockedPorts = []uint16{53, 853}
|
||||
|
||||
// blockedPorts returns the effective port list, honoring env overrides.
|
||||
// A nil return means the firewall should not be installed.
|
||||
func blockedPorts() []uint16 {
|
||||
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
|
||||
log.Infof("dns firewall disabled via %s", EnvDisable)
|
||||
return nil
|
||||
}
|
||||
|
||||
override, ok := os.LookupEnv(EnvPorts)
|
||||
if !ok {
|
||||
return defaultBlockedPorts
|
||||
}
|
||||
|
||||
var ports []uint16
|
||||
for _, raw := range strings.Split(override, ",") {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.ParseUint(raw, 10, 16)
|
||||
if err != nil {
|
||||
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
|
||||
continue
|
||||
}
|
||||
if port == 0 {
|
||||
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
|
||||
continue
|
||||
}
|
||||
ports = append(ports, uint16(port))
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
|
||||
return nil
|
||||
}
|
||||
return ports
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBlockedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
disable string
|
||||
ports string
|
||||
setPorts bool
|
||||
want []uint16
|
||||
}{
|
||||
{name: "default", want: defaultBlockedPorts},
|
||||
{name: "disabled", disable: "true", want: nil},
|
||||
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
|
||||
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
|
||||
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
|
||||
{name: "override empty disables", ports: "", setPorts: true, want: nil},
|
||||
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvDisable, tc.disable)
|
||||
if tc.setPorts {
|
||||
t.Setenv(EnvPorts, tc.ports)
|
||||
}
|
||||
got := blockedPorts()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
|
||||
// managing the host's DNS, so that resolvers running on apps or libraries
|
||||
// outside netbird cannot bypass the configured DNS path.
|
||||
//
|
||||
// Implementation is Windows-only (uses WFP). On other platforms New returns
|
||||
// a no-op manager.
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
|
||||
// to call multiple times.
|
||||
type Manager interface {
|
||||
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
|
||||
Disable() error
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type noopManager struct{}
|
||||
|
||||
func (noopManager) Enable(string, netip.Addr) error { return nil }
|
||||
func (noopManager) Disable() error { return nil }
|
||||
|
||||
// New returns a no-op manager on non-Windows platforms.
|
||||
func New() Manager {
|
||||
return noopManager{}
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
|
||||
)
|
||||
|
||||
type windowsManager struct {
|
||||
mu sync.Mutex
|
||||
// session is the WFP engine handle. Zero when disabled.
|
||||
session uintptr
|
||||
}
|
||||
|
||||
// Enable installs the dns firewall. Strict mode propagates failures;
|
||||
// non-strict mode logs and returns nil so partial protection is preserved.
|
||||
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
ports := blockedPorts()
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.session != 0 {
|
||||
if err := m.disableLocked(); err != nil {
|
||||
return fmt.Errorf("reset existing dns firewall session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
strict := strictMode()
|
||||
|
||||
luid, err := luidFromGUID(ifaceGUID)
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
|
||||
}
|
||||
|
||||
cfg := installConfig{
|
||||
tunLUID: luid,
|
||||
daemonExe: exe,
|
||||
blockedPorts: ports,
|
||||
strict: strict,
|
||||
virtualDNSIP: virtualDNSIP,
|
||||
}
|
||||
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
|
||||
session, installErr := installFilters(cfg)
|
||||
if session == 0 {
|
||||
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
|
||||
}
|
||||
|
||||
if installErr != nil && strict {
|
||||
_ = closeSession(session)
|
||||
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
|
||||
}
|
||||
|
||||
m.session = session
|
||||
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
|
||||
ifaceGUID, exe, ports, strict, virtualDNSIP)
|
||||
if installErr != nil {
|
||||
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *windowsManager) Disable() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.disableLocked()
|
||||
}
|
||||
|
||||
func (m *windowsManager) disableLocked() error {
|
||||
if m.session == 0 {
|
||||
return nil
|
||||
}
|
||||
session := m.session
|
||||
m.session = 0
|
||||
if err := closeSession(session); err != nil {
|
||||
return fmt.Errorf("close wfp session: %w", err)
|
||||
}
|
||||
log.Info("dns firewall removed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// failOrLog returns err unchanged in strict mode. In non-strict mode the
|
||||
// error is logged and nil is returned.
|
||||
func (m *windowsManager) failOrLog(strict bool, err error) error {
|
||||
if strict {
|
||||
return err
|
||||
}
|
||||
log.Errorf("dns firewall: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// New returns a Windows DNS firewall manager backed by WFP.
|
||||
func New() Manager {
|
||||
return &windowsManager{}
|
||||
}
|
||||
|
||||
// strictMode reports whether strict mode is enabled via env.
|
||||
func strictMode() bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
|
||||
return v
|
||||
}
|
||||
|
||||
// luidFromGUID converts a Windows interface GUID string to its LUID.
|
||||
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in luidFromGUID: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
guid, err := windows.GUIDFromString(ifaceGUID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse guid: %w", err)
|
||||
}
|
||||
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
if rc != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
|
||||
}
|
||||
return luid, nil
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrictMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
set bool
|
||||
want bool
|
||||
}{
|
||||
{name: "unset", want: false},
|
||||
{name: "true", val: "true", set: true, want: true},
|
||||
{name: "1", val: "1", set: true, want: true},
|
||||
{name: "false", val: "false", set: true, want: false},
|
||||
{name: "invalid is false", val: "garbage", set: true, want: false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvStrict, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(EnvStrict)
|
||||
}
|
||||
if got := strictMode(); got != tc.want {
|
||||
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerDisableIdempotent(t *testing.T) {
|
||||
m := &windowsManager{}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("first Disable on fresh manager: %v", err)
|
||||
}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("second Disable on fresh manager: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session should remain zero, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
|
||||
t.Setenv(EnvDisable, "true")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
|
||||
t.Setenv(EnvPorts, "")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
|
||||
}
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
|
||||
namePtr, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
descriptionPtr, err := windows.UTF16PtrFromString(description)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
return &wtFwpmDisplayData0{
|
||||
name: namePtr,
|
||||
description: descriptionPtr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func filterWeight(weight uint8) wtFwpValue0 {
|
||||
return wtFwpValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(weight),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapErr(err error) error {
|
||||
var errno syscall.Errno
|
||||
if !errors.As(err, &errno) {
|
||||
return err
|
||||
}
|
||||
_, file, line, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
return fmt.Errorf("wfp error at unknown location: %w", err)
|
||||
}
|
||||
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
|
||||
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
|
||||
* uses for its kill-switch DNS leak protection. We extend it with a
|
||||
* configurable port set so we also cover :853 (DoT) and any future ports.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
|
||||
// follow the authorized outbound flow.
|
||||
|
||||
// permitTunInterface installs a permit filter for any traffic whose local
|
||||
// interface is the netbird tunnel.
|
||||
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT64,
|
||||
value: uintptr(unsafe.Pointer(&ifLUID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
|
||||
}
|
||||
|
||||
// permitDaemonByAppID installs a permit filter matching the netbird daemon
|
||||
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
|
||||
// dedicated binary.
|
||||
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
|
||||
appID, err := daemonAppID(daemonExe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
|
||||
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_BLOB_TYPE,
|
||||
value: uintptr(unsafe.Pointer(appID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird daemon")
|
||||
}
|
||||
|
||||
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
|
||||
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
|
||||
// permitTunInterface.
|
||||
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
|
||||
if !ip.IsValid() {
|
||||
return fmt.Errorf("invalid address")
|
||||
}
|
||||
|
||||
var addrCond wtFwpmFilterCondition0
|
||||
var layer windows.GUID
|
||||
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
|
||||
var v6 wtFwpByteArray16
|
||||
|
||||
if ip.Is4() {
|
||||
v4 := ip.As4()
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT32,
|
||||
value: uintptr(binary.BigEndian.Uint32(v4[:])),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
|
||||
} else {
|
||||
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_ARRAY16_TYPE,
|
||||
value: uintptr(unsafe.Pointer(&v6)),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
|
||||
}
|
||||
|
||||
conditions := [2]wtFwpmFilterCondition0{
|
||||
addrCond,
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
}
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
_ = v6
|
||||
return nil
|
||||
}
|
||||
|
||||
// blockDNSPorts installs a deny filter for outbound traffic to each of the
|
||||
// given remote ports over UDP or TCP. Per-port and per-layer failures are
|
||||
// accumulated; partial coverage is preferred over zero coverage.
|
||||
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := blockDNSPort(session, base, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
|
||||
conditions := [3]wtFwpmFilterCondition0{
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_UDP),
|
||||
},
|
||||
},
|
||||
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_TCP),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
|
||||
}
|
||||
|
||||
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
|
||||
// connect layers. v4 and v6 are installed independently: failure on one
|
||||
// layer does not abort the other, and the accumulated errors are returned.
|
||||
// Partial coverage is preferred over zero coverage.
|
||||
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
|
||||
layers := [...]struct {
|
||||
layer windows.GUID
|
||||
label string
|
||||
}{
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, l := range layers {
|
||||
display, err := createWtFwpmDisplayData0(l.label, "")
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
continue
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = l.layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Session lifecycle and the high-level Install/Close entry points adapted
|
||||
* from wireguard-windows tunnel/firewall.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// installConfig is the input to installFilters.
|
||||
type installConfig struct {
|
||||
tunLUID uint64
|
||||
daemonExe string
|
||||
blockedPorts []uint16
|
||||
// strict, when true, narrows the carve-out from "anything on tun" to
|
||||
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
|
||||
strict bool
|
||||
virtualDNSIP netip.Addr
|
||||
}
|
||||
|
||||
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
|
||||
// for our session. Both are randomly generated per session.
|
||||
type baseObjects struct {
|
||||
provider windows.GUID
|
||||
filters windows.GUID
|
||||
}
|
||||
|
||||
// installFilters opens a dynamic WFP session and installs the netbird DNS
|
||||
// firewall filters. Returns a zero session on hard failure (session create,
|
||||
// base objects); a non-zero session with a non-nil error is a partial install
|
||||
// (some per-filter installs failed) and is safe to close.
|
||||
func installFilters(cfg installConfig) (session uintptr, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Dynamic session: kernel will clean up on process exit even
|
||||
// if we leave the handle dangling here.
|
||||
err = fmt.Errorf("panic in installFilters: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(cfg.blockedPorts) == 0 {
|
||||
return 0, errors.New("dns firewall: no blocked ports configured")
|
||||
}
|
||||
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
|
||||
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
|
||||
}
|
||||
|
||||
session, err = createSession()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
base, err := registerBaseObjects(session)
|
||||
if err != nil {
|
||||
_ = fwpmEngineClose0(session)
|
||||
return 0, fmt.Errorf("register base objects: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if cfg.strict {
|
||||
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
|
||||
}
|
||||
} else {
|
||||
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
|
||||
}
|
||||
}
|
||||
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
|
||||
}
|
||||
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
|
||||
}
|
||||
|
||||
return session, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// closeSession tears down a WFP session previously opened by installFilters.
|
||||
// All filters owned by the session are removed.
|
||||
func closeSession(session uintptr) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in closeSession: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if session == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := fwpmEngineClose0(session); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSession() (uintptr, error) {
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
|
||||
if err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
session := wtFwpmSession0{
|
||||
displayData: *displayData,
|
||||
flags: cFWPM_SESSION_FLAG_DYNAMIC,
|
||||
txnWaitTimeoutInMSec: windows.INFINITE,
|
||||
}
|
||||
var handle uintptr
|
||||
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func registerBaseObjects(session uintptr) (*baseObjects, error) {
|
||||
bo := &baseObjects{}
|
||||
var err error
|
||||
if bo.provider, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
if bo.filters, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
provider := wtFwpmProvider0{
|
||||
providerKey: bo.provider,
|
||||
displayData: *displayData,
|
||||
}
|
||||
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
sublayer := wtFwpmSublayer0{
|
||||
subLayerKey: bo.filters,
|
||||
displayData: *subDisplay,
|
||||
providerKey: &bo.provider,
|
||||
weight: ^uint16(0),
|
||||
}
|
||||
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return bo, nil
|
||||
}
|
||||
|
||||
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
|
||||
func daemonAppID(path string) (*wtFwpByteBlob, error) {
|
||||
pathPtr, err := windows.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
var appID *wtFwpByteBlob
|
||||
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return appID, nil
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
|
||||
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
|
||||
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
|
||||
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
|
||||
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
|
||||
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
|
||||
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
|
||||
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
|
||||
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
|
||||
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
|
||||
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0
|
||||
@@ -1,414 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
|
||||
|
||||
wtFwpBitmapArray64_Size = 8
|
||||
|
||||
wtFwpByteArray16_Size = 16
|
||||
|
||||
wtFwpByteArray6_Size = 6
|
||||
|
||||
wtFwpmAction0_Size = 20
|
||||
wtFwpmAction0_filterType_Offset = 4
|
||||
|
||||
wtFwpV4AddrAndMask_Size = 8
|
||||
wtFwpV4AddrAndMask_mask_Offset = 4
|
||||
|
||||
wtFwpV6AddrAndMask_Size = 17
|
||||
wtFwpV6AddrAndMask_prefixLength_Offset = 16
|
||||
)
|
||||
|
||||
type wtFwpActionFlag uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
|
||||
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
|
||||
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
|
||||
)
|
||||
|
||||
// FWP_ACTION_TYPE defined in fwptypes.h
|
||||
type wtFwpActionType uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
|
||||
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
|
||||
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
|
||||
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
|
||||
)
|
||||
|
||||
// FWP_BYTE_BLOB defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
|
||||
type wtFwpByteBlob struct {
|
||||
size uint32
|
||||
data *uint8
|
||||
}
|
||||
|
||||
// FWP_MATCH_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
|
||||
type wtFwpMatchType uint32
|
||||
|
||||
const (
|
||||
cFWP_MATCH_EQUAL wtFwpMatchType = 0
|
||||
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
|
||||
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
|
||||
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
|
||||
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
|
||||
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
|
||||
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
|
||||
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
|
||||
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
|
||||
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
|
||||
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
|
||||
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
|
||||
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
|
||||
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
|
||||
)
|
||||
|
||||
// FWPM_ACTION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
|
||||
type wtFwpmAction0 struct {
|
||||
_type wtFwpActionType
|
||||
filterType windows.GUID // Windows type: GUID
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
|
||||
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
|
||||
Data1: 0x4cd62a49,
|
||||
Data2: 0x59c3,
|
||||
Data3: 0x4969,
|
||||
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
|
||||
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
|
||||
Data1: 0xb235ae9a,
|
||||
Data2: 0x1d64,
|
||||
Data3: 0x49b8,
|
||||
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
|
||||
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
|
||||
Data1: 0x3971ef2b,
|
||||
Data2: 0x623e,
|
||||
Data3: 0x4f9a,
|
||||
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
|
||||
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
|
||||
Data1: 0x0c1ba1af,
|
||||
Data2: 0x5765,
|
||||
Data3: 0x453f,
|
||||
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
|
||||
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
|
||||
Data1: 0xc35a604d,
|
||||
Data2: 0xd22b,
|
||||
Data3: 0x4e1a,
|
||||
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
|
||||
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
|
||||
Data1: 0xd78e1e87,
|
||||
Data2: 0x8644,
|
||||
Data3: 0x4ea5,
|
||||
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
|
||||
}
|
||||
|
||||
// af043a0a-b34d-4f86-979c-c90371af6e66
|
||||
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
|
||||
Data1: 0xaf043a0a,
|
||||
Data2: 0xb34d,
|
||||
Data3: 0x4f86,
|
||||
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
|
||||
}
|
||||
|
||||
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
|
||||
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
|
||||
Data1: 0xd9ee00de,
|
||||
Data2: 0xc1ef,
|
||||
Data3: 0x4617,
|
||||
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
|
||||
}
|
||||
|
||||
var (
|
||||
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
|
||||
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
|
||||
)
|
||||
|
||||
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
|
||||
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
|
||||
Data1: 0x7bc43cbf,
|
||||
Data2: 0x37ba,
|
||||
Data3: 0x45f1,
|
||||
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
|
||||
}
|
||||
|
||||
type wtFwpmL2Flags uint32
|
||||
|
||||
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
|
||||
|
||||
var cFWPM_CONDITION_FLAGS = windows.GUID{
|
||||
Data1: 0x632ce23b,
|
||||
Data2: 0x5167,
|
||||
Data3: 0x435c,
|
||||
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
|
||||
}
|
||||
|
||||
type wtFwpmFlags uint32
|
||||
|
||||
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
|
||||
|
||||
// Defined in fwpmtypes.h
|
||||
type wtFwpmFilterFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
|
||||
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
|
||||
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
|
||||
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
|
||||
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
|
||||
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
|
||||
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
|
||||
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
|
||||
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
|
||||
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
|
||||
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
|
||||
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
|
||||
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
|
||||
)
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
|
||||
Data1: 0xc38d57d1,
|
||||
Data2: 0x05a7,
|
||||
Data3: 0x4c33,
|
||||
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
|
||||
}
|
||||
|
||||
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
|
||||
Data1: 0xe1cd9fe7,
|
||||
Data2: 0xf4b5,
|
||||
Data3: 0x4273,
|
||||
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
|
||||
}
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
|
||||
Data1: 0x4a72393b,
|
||||
Data2: 0x319f,
|
||||
Data3: 0x44bc,
|
||||
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
|
||||
}
|
||||
|
||||
// a3b42c97-9f04-4672-b87e-cee9c483257f
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
|
||||
Data1: 0xa3b42c97,
|
||||
Data2: 0x9f04,
|
||||
Data3: 0x4672,
|
||||
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
|
||||
}
|
||||
|
||||
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
|
||||
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0x94c44912,
|
||||
Data2: 0x9d6f,
|
||||
Data3: 0x4ebf,
|
||||
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
|
||||
}
|
||||
|
||||
// d4220bd3-62ce-4f08-ae88-b56e8526df50
|
||||
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0xd4220bd3,
|
||||
Data2: 0x62ce,
|
||||
Data3: 0x4f08,
|
||||
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
|
||||
}
|
||||
|
||||
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
|
||||
type wtFwpBitmapArray64 struct {
|
||||
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY6 defined in fwtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
|
||||
type wtFwpByteArray6 struct {
|
||||
byteArray6 [6]uint8 // Windows type: [6]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY16 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
|
||||
type wtFwpByteArray16 struct {
|
||||
byteArray16 [16]uint8 // Windows type [16]UINT8
|
||||
}
|
||||
|
||||
// FWP_CONDITION_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
|
||||
type wtFwpConditionValue0 wtFwpValue0
|
||||
|
||||
// FWP_DATA_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
|
||||
type wtFwpDataType uint
|
||||
|
||||
const (
|
||||
cFWP_EMPTY wtFwpDataType = 0
|
||||
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
|
||||
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
|
||||
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
|
||||
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
|
||||
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
|
||||
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
|
||||
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
|
||||
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
|
||||
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
|
||||
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
|
||||
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
|
||||
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
|
||||
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
|
||||
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
|
||||
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
|
||||
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
|
||||
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
|
||||
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
|
||||
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
|
||||
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
|
||||
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
|
||||
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
|
||||
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
|
||||
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
|
||||
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
|
||||
)
|
||||
|
||||
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
|
||||
type wtFwpV4AddrAndMask struct {
|
||||
addr uint32
|
||||
mask uint32
|
||||
}
|
||||
|
||||
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
|
||||
type wtFwpV6AddrAndMask struct {
|
||||
addr [16]uint8
|
||||
prefixLength uint8
|
||||
}
|
||||
|
||||
// FWP_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
|
||||
type wtFwpValue0 struct {
|
||||
_type wtFwpDataType
|
||||
value uintptr
|
||||
}
|
||||
|
||||
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
|
||||
type wtFwpmDisplayData0 struct {
|
||||
name *uint16 // Windows type: *wchar_t
|
||||
description *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
|
||||
type wtFwpmFilterCondition0 struct {
|
||||
fieldKey windows.GUID // Windows type: GUID
|
||||
matchType wtFwpMatchType
|
||||
conditionValue wtFwpConditionValue0
|
||||
}
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
|
||||
type wtFwpProvider0 struct {
|
||||
providerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
type wtFwpmSessionFlagsValue uint32
|
||||
|
||||
const (
|
||||
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SESSION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
|
||||
type wtFwpmSession0 struct {
|
||||
sessionKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSessionFlagsValue // Windows type UINT32
|
||||
txnWaitTimeoutInMSec uint32
|
||||
processId uint32 // Windows type: DWORD
|
||||
sid *windows.SID
|
||||
username *uint16 // Windows type: *wchar_t
|
||||
kernelMode uint8 // Windows type: BOOL
|
||||
}
|
||||
|
||||
type wtFwpmSublayerFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SUBLAYER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
|
||||
type wtFwpmSublayer0 struct {
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSublayerFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
weight uint16
|
||||
}
|
||||
|
||||
// Defined in rpcdce.h
|
||||
type wtRpcCAuthN uint32
|
||||
|
||||
const (
|
||||
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
|
||||
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
|
||||
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
|
||||
)
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
|
||||
type wtFwpmProvider0 struct {
|
||||
providerKey windows.GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16
|
||||
}
|
||||
|
||||
type wtIPProto uint32
|
||||
|
||||
const (
|
||||
cIPPROTO_ICMP wtIPProto = 1
|
||||
cIPPROTO_ICMPV6 wtIPProto = 58
|
||||
cIPPROTO_TCP wtIPProto = 6
|
||||
cIPPROTO_UDP wtIPProto = 17
|
||||
)
|
||||
|
||||
const (
|
||||
cFWP_ACTRL_MATCH_FILTER = 1
|
||||
)
|
||||
@@ -1,92 +0,0 @@
|
||||
//go:build windows && (386 || arm)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 8
|
||||
wtFwpByteBlob_data_Offset = 4
|
||||
|
||||
wtFwpConditionValue0_Size = 8
|
||||
wtFwpConditionValue0_uint8_Offset = 4
|
||||
|
||||
wtFwpmDisplayData0_Size = 8
|
||||
wtFwpmDisplayData0_description_Offset = 4
|
||||
|
||||
wtFwpmFilter0_Size = 152
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 24
|
||||
wtFwpmFilter0_providerKey_Offset = 28
|
||||
wtFwpmFilter0_providerData_Offset = 32
|
||||
wtFwpmFilter0_layerKey_Offset = 40
|
||||
wtFwpmFilter0_subLayerKey_Offset = 56
|
||||
wtFwpmFilter0_weight_Offset = 72
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 80
|
||||
wtFwpmFilter0_filterCondition_Offset = 84
|
||||
wtFwpmFilter0_action_Offset = 88
|
||||
wtFwpmFilter0_providerContextKey_Offset = 112
|
||||
wtFwpmFilter0_reserved_Offset = 128
|
||||
wtFwpmFilter0_filterID_Offset = 136
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 144
|
||||
|
||||
wtFwpmFilterCondition0_Size = 28
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 20
|
||||
|
||||
wtFwpmSession0_Size = 48
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 24
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
|
||||
wtFwpmSession0_processId_Offset = 32
|
||||
wtFwpmSession0_sid_Offset = 36
|
||||
wtFwpmSession0_username_Offset = 40
|
||||
wtFwpmSession0_kernelMode_Offset = 44
|
||||
|
||||
wtFwpmSublayer0_Size = 44
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 24
|
||||
wtFwpmSublayer0_providerKey_Offset = 28
|
||||
wtFwpmSublayer0_providerData_Offset = 32
|
||||
wtFwpmSublayer0_weight_Offset = 40
|
||||
|
||||
wtFwpProvider0_Size = 40
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 24
|
||||
wtFwpProvider0_providerData_Offset = 28
|
||||
wtFwpProvider0_serviceName_Offset = 36
|
||||
|
||||
wtFwpTokenInformation_Size = 16
|
||||
|
||||
wtFwpValue0_Size = 8
|
||||
wtFwpValue0_value_Offset = 4
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
offset2 [4]byte // Layout correction field
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
//go:build windows && (amd64 || arm64)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 16
|
||||
wtFwpByteBlob_data_Offset = 8
|
||||
|
||||
wtFwpConditionValue0_Size = 16
|
||||
wtFwpConditionValue0_uint8_Offset = 8
|
||||
|
||||
wtFwpmDisplayData0_Size = 16
|
||||
wtFwpmDisplayData0_description_Offset = 8
|
||||
|
||||
wtFwpmFilter0_Size = 200
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 32
|
||||
wtFwpmFilter0_providerKey_Offset = 40
|
||||
wtFwpmFilter0_providerData_Offset = 48
|
||||
wtFwpmFilter0_layerKey_Offset = 64
|
||||
wtFwpmFilter0_subLayerKey_Offset = 80
|
||||
wtFwpmFilter0_weight_Offset = 96
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 112
|
||||
wtFwpmFilter0_filterCondition_Offset = 120
|
||||
wtFwpmFilter0_action_Offset = 128
|
||||
wtFwpmFilter0_providerContextKey_Offset = 152
|
||||
wtFwpmFilter0_reserved_Offset = 168
|
||||
wtFwpmFilter0_filterID_Offset = 176
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 184
|
||||
|
||||
wtFwpmFilterCondition0_Size = 40
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 24
|
||||
|
||||
wtFwpmSession0_Size = 72
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 32
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
|
||||
wtFwpmSession0_processId_Offset = 40
|
||||
wtFwpmSession0_sid_Offset = 48
|
||||
wtFwpmSession0_username_Offset = 56
|
||||
wtFwpmSession0_kernelMode_Offset = 64
|
||||
|
||||
wtFwpmSublayer0_Size = 72
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 32
|
||||
wtFwpmSublayer0_providerKey_Offset = 40
|
||||
wtFwpmSublayer0_providerData_Offset = 48
|
||||
wtFwpmSublayer0_weight_Offset = 64
|
||||
|
||||
wtFwpProvider0_Size = 64
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 32
|
||||
wtFwpProvider0_providerData_Offset = 40
|
||||
wtFwpProvider0_serviceName_Offset = 56
|
||||
|
||||
wtFwpValue0_Size = 16
|
||||
wtFwpValue0_value_Offset = 8
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags // Windows type: UINT32
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
|
||||
|
||||
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
|
||||
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
|
||||
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
|
||||
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
|
||||
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
|
||||
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
|
||||
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
|
||||
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
|
||||
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
|
||||
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
|
||||
)
|
||||
|
||||
func fwpmEngineClose0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFreeMemory0(p unsafe.Pointer) {
|
||||
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
case entry.Pattern == ".":
|
||||
return true
|
||||
case entry.IsWildcard:
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
return strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
default:
|
||||
// For non-wildcard patterns:
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
|
||||
@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary mismatch (suffix overlap)",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.ab.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard multi-label match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.y.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on multi-label apex",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on unrelated suffix containment",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "notexample.com.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard accepts pattern registered without trailing dot",
|
||||
handlerDomain: "*.b.test",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority matching handler should be called
|
||||
},
|
||||
{
|
||||
name: "overlapping wildcard suffixes route to correct handler",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
|
||||
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "app.ab.test.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1,
|
||||
},
|
||||
{
|
||||
name: "root zone with specific domain",
|
||||
handlers: []struct {
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
@@ -75,7 +74,6 @@ type registryConfigurator struct {
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
dnsFirewall dnsfw.Manager
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
@@ -96,9 +94,8 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
}
|
||||
|
||||
configurator := ®istryConfigurator{
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
}
|
||||
|
||||
origNameservers, err := configurator.captureOriginalNameservers()
|
||||
@@ -279,8 +276,16 @@ func (r *registryConfigurator) disableWINSForInterface() error {
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
if err := r.applyRouteAll(config); err != nil {
|
||||
return err
|
||||
if config.RouteAll {
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
@@ -322,35 +327,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
|
||||
if config.RouteAll {
|
||||
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
|
||||
return fmt.Errorf("dns firewall: %w", err)
|
||||
}
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
|
||||
if dErr := r.dnsFirewall.Disable(); dErr != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
if !r.routingAll {
|
||||
return nil
|
||||
}
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
@@ -537,10 +513,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
|
||||
go r.flushDNSCache()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
@@ -36,9 +34,8 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
@@ -137,9 +134,8 @@ func TestNRPTDomainBatching(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -26,6 +26,19 @@ type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
|
||||
// client knows about and whether that peer is currently connected. The
|
||||
// local resolver uses this to suppress A/AAAA answers whose RDATA points
|
||||
// at a disconnected peer (typical case: a synthesized private-service
|
||||
// record pointing at an embedded proxy peer that just went offline).
|
||||
//
|
||||
// known=false means the IP isn't in the local peerstore at all — the
|
||||
// record is left alone (it points at something outside our mesh, e.g.
|
||||
// a non-peer upstream).
|
||||
type PeerConnectivity interface {
|
||||
IsConnectedByIP(ip string) (known, connected bool)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
@@ -33,6 +46,11 @@ type Resolver struct {
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
// peerConn, when non-nil, is consulted on every A/AAAA answer to
|
||||
// drop records pointing at disconnected peers. nil disables the
|
||||
// filter and preserves the legacy "return whatever is registered"
|
||||
// behaviour for callers that never wire a status source.
|
||||
peerConn PeerConnectivity
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
|
||||
}
|
||||
}
|
||||
|
||||
// SetPeerConnectivity wires the per-IP connectivity check used to filter
|
||||
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
|
||||
// Safe to call multiple times; the latest value wins.
|
||||
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.peerConn = p
|
||||
}
|
||||
|
||||
func (d *Resolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
@@ -95,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
result := d.lookupRecords(logger, question)
|
||||
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
@@ -436,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
}
|
||||
}
|
||||
|
||||
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
|
||||
// a known but disconnected peer. The synthesized private-service zones
|
||||
// emit one A record per connected proxy peer in a cluster; when a peer
|
||||
// goes offline, the server-side refresh removes the record from the
|
||||
// next netmap, but the client may still hold the previous netmap for a
|
||||
// short window. This filter is the local belt to that braces — even on
|
||||
// the stale netmap, the resolver hides the offline target.
|
||||
//
|
||||
// Records pointing at unknown IPs (outside the local peerstore, e.g.
|
||||
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
|
||||
// through untouched.
|
||||
//
|
||||
// Escape hatch: if filtering would leave the answer empty AND at least
|
||||
// one record was filtered, the original list is returned. Better to
|
||||
// hand the client a record that may not respond than NXDOMAIN it
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
checker := d.peerConn
|
||||
d.mu.RUnlock()
|
||||
if checker == nil {
|
||||
return records
|
||||
}
|
||||
|
||||
kept := make([]dns.RR, 0, len(records))
|
||||
var dropped int
|
||||
for _, rr := range records {
|
||||
ip := extractRecordIP(rr)
|
||||
if ip == "" {
|
||||
kept = append(kept, rr)
|
||||
continue
|
||||
}
|
||||
known, connected := checker.IsConnectedByIP(ip)
|
||||
if known && !connected {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
kept = append(kept, rr)
|
||||
}
|
||||
if dropped == 0 {
|
||||
return records
|
||||
}
|
||||
if len(kept) == 0 {
|
||||
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
|
||||
return records
|
||||
}
|
||||
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
|
||||
return kept
|
||||
}
|
||||
|
||||
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
|
||||
// an A or AAAA record, or "" for any other record type.
|
||||
func extractRecordIP(rr dns.RR) string {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
if r.A == nil {
|
||||
return ""
|
||||
}
|
||||
return r.A.String()
|
||||
case *dns.AAAA:
|
||||
if r.AAAA == nil {
|
||||
return ""
|
||||
}
|
||||
return r.AAAA.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
d.mu.Lock()
|
||||
|
||||
@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mockPeerConnectivity returns canned (known, connected) results per IP.
|
||||
// Used by the disconnected-peer filter tests below. IPs not in the map
|
||||
// are reported as unknown so the filter leaves them alone.
|
||||
type mockPeerConnectivity struct {
|
||||
byIP map[string]struct{ known, connected bool }
|
||||
}
|
||||
|
||||
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||
v, ok := m.byIP[ip]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return v.known, v.connected
|
||||
}
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
||||
resolver.isInManagedZone(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
|
||||
// connectivity-aware filtering layered on top of lookupRecords:
|
||||
// when an A record's IP belongs to a known peer that's disconnected,
|
||||
// the record is dropped from the answer. Records for unknown IPs pass
|
||||
// through. If filtering would empty the answer entirely and at least
|
||||
// one record was dropped, the original list is restored (escape hatch
|
||||
// for the "all proxies offline" case).
|
||||
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
zone := "svc.cluster.netbird."
|
||||
connectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.10",
|
||||
}
|
||||
disconnectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.11",
|
||||
}
|
||||
unknownRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "203.0.113.5",
|
||||
}
|
||||
|
||||
type ipState struct{ known, connected bool }
|
||||
tests := []struct {
|
||||
name string
|
||||
records []nbdns.SimpleRecord
|
||||
connByIP map[string]ipState
|
||||
wantInOrder []string
|
||||
}{
|
||||
{
|
||||
name: "drops disconnected peer, keeps connected",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: true},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "unknown IPs pass through untouched",
|
||||
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"203.0.113.5"},
|
||||
},
|
||||
{
|
||||
name: "all disconnected falls back to original list",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: false},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "no checker wired returns all records",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
if tc.connByIP != nil {
|
||||
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
|
||||
for ip, st := range tc.connByIP {
|
||||
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
|
||||
}
|
||||
resolver.SetPeerConnectivity(cm)
|
||||
}
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: strings.TrimSuffix(zone, "."),
|
||||
Records: tc.records,
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
var got *dns.Msg
|
||||
writer := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
got = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
|
||||
resolver.ServeDNS(writer, req)
|
||||
|
||||
require.NotNil(t, got, "resolver must produce a response")
|
||||
require.Len(t, got.Answer, len(tc.wantInOrder),
|
||||
"answer count must match expected: %v", tc.wantInOrder)
|
||||
for i, want := range tc.wantInOrder {
|
||||
a, ok := got.Answer[i].(*dns.A)
|
||||
require.True(t, ok, "answer[%d] must be an A record", i)
|
||||
assert.Equal(t, want, a.A.String(),
|
||||
"answer[%d] expected %s got %s", i, want, a.A.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,6 +301,11 @@ func newDefaultServer(
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
healthRefresh: make(chan struct{}, 1),
|
||||
}
|
||||
// Wire the local resolver against the peer status recorder so it can
|
||||
// suppress A/AAAA answers that point at disconnected peers (typical
|
||||
// case: synthesised private-service records pointing at an embedded
|
||||
// proxy peer that just went offline).
|
||||
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
dnsService.RegisterMux(".", handlerChain)
|
||||
@@ -1386,3 +1391,25 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
|
||||
// the local resolver can ask "is this IP a known peer and is it
|
||||
// connected?" without taking on the peer package as a dependency.
|
||||
// A nil status recorder always reports known=false so the resolver
|
||||
// short-circuits to the legacy "return everything" path.
|
||||
type localPeerConnectivity struct {
|
||||
status *peer.Status
|
||||
}
|
||||
|
||||
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
|
||||
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
|
||||
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||
if l.status == nil {
|
||||
return false, false
|
||||
}
|
||||
state, ok := l.status.PeerStateByIP(ip)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return true, state.ConnStatus == peer.StatusConnected
|
||||
}
|
||||
|
||||
@@ -1967,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||
// See Engine.SetPerformance. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetPerformance applies the given tuning to this engine's live Device.
|
||||
func (e *Engine) SetPerformance(t Performance) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
if e.wgInterface == nil {
|
||||
return fmt.Errorf("wg interface not initialized")
|
||||
}
|
||||
dev := e.wgInterface.GetWGDevice()
|
||||
if dev == nil {
|
||||
return fmt.Errorf("wg device not initialized")
|
||||
}
|
||||
if t.PreallocatedBuffersPerPool != nil {
|
||||
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -66,8 +66,8 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
route, flags, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
@@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
if systemops.IgnoreAddedDefaultRoute(flags) {
|
||||
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
|
||||
continue
|
||||
}
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
@@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
return nil, 0, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
r, err := systemops.MsgToRoute(msg)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return r, msg.Flags, nil
|
||||
}
|
||||
|
||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||
|
||||
@@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
return s.eventsChan
|
||||
}
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
// Status holds a state of peers, signal, management connections and relays.
|
||||
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
|
||||
// every private-service request) don't contend against each other.
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
@@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
|
||||
// GetPeer adds peer to Daemon status map
|
||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
@@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
}
|
||||
|
||||
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if state.IP == ip {
|
||||
@@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Returns the
|
||||
// zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
@@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
|
||||
|
||||
// GetLocalPeerState returns the local peer state
|
||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.localPeer.Clone()
|
||||
}
|
||||
|
||||
@@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
d.rosenpassPermissive,
|
||||
@@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
||||
}
|
||||
|
||||
func (d *Status) GetLazyConnection() bool {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.lazyConnectionEnabled
|
||||
}
|
||||
|
||||
func (d *Status) GetManagementState() ManagementState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return ManagementState{
|
||||
d.mgmAddress,
|
||||
d.managementState,
|
||||
@@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
|
||||
// IsLoginRequired determines if a peer's login has expired.
|
||||
func (d *Status) IsLoginRequired() bool {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
// if peer is connected to the management then login is not expired
|
||||
if d.managementState {
|
||||
@@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool {
|
||||
}
|
||||
|
||||
func (d *Status) GetSignalState() SignalState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return SignalState{
|
||||
d.signalAddress,
|
||||
d.signalState,
|
||||
@@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.relayMgr == nil {
|
||||
return d.relayStates
|
||||
}
|
||||
@@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.ingressGwMgr == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
}
|
||||
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
// shallow copy is good enough, as slices fields are currently not updated
|
||||
return slices.Clone(d.nsGroupStates)
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
}
|
||||
|
||||
@@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
LazyConnectionEnabled: d.GetLazyConnection(),
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
fullStatus.LocalPeerState = d.localPeer
|
||||
|
||||
@@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
}
|
||||
|
||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.wgIface == nil {
|
||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||
}
|
||||
|
||||
@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
|
||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
|
||||
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
|
||||
|
||||
state, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "known tunnel IP should resolve to a peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.99")
|
||||
req.False(ok, "unknown IP must report ok=false")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
state, ok := status.PeerStateByIP("fd00::1")
|
||||
req.True(ok, "IPv6-only match must resolve to the peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
key := "abc"
|
||||
fqdn := "peer-a.netbird.local"
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
|
||||
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||
// given flags should be ignored by the network monitor.
|
||||
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||
return filterRoutesByFlags(flags)
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
//go:build darwin
|
||||
|
||||
package systemops
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||
// given flags should be ignored by the network monitor. Scoped routes
|
||||
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
|
||||
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
|
||||
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
|
||||
// must not trigger an engine restart.
|
||||
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||
if filterRoutesByFlags(flags) {
|
||||
return true
|
||||
}
|
||||
if flags&unix.RTF_IFSCOPE != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -188,7 +188,9 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
|
||||
}
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
|
||||
// enough for teardown to finish while staying under that deadline.
|
||||
timeout := time.NewTimer(20 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3,15 +3,14 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zcalusic/sysinfo"
|
||||
|
||||
@@ -29,19 +28,11 @@ func UpdateStaticInfoAsync() {
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
info := _getInfo()
|
||||
for strings.Contains(info, "broken pipe") {
|
||||
info = _getInfo()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
osStr := strings.ReplaceAll(info, "\n", "")
|
||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
kernelName, kernelVersion, kernelPlatform := kernelInfo()
|
||||
|
||||
osName, osVersion := readOsReleaseFile()
|
||||
if osName == "" {
|
||||
osName = osInfo[3]
|
||||
osName = kernelName
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
@@ -58,8 +49,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
Kernel: osInfo[0],
|
||||
Platform: osInfo[2],
|
||||
Kernel: kernelName,
|
||||
Platform: kernelPlatform,
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
@@ -67,7 +58,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
KernelVersion: kernelVersion,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
@@ -78,18 +69,12 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func _getInfo() string {
|
||||
cmd := exec.Command("uname", "-srio")
|
||||
cmd.Stdin = strings.NewReader("some")
|
||||
var out bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Warnf("getInfo: %s", err)
|
||||
func kernelInfo() (string, string, string) {
|
||||
var uts unix.Utsname
|
||||
if err := unix.Uname(&uts); err != nil {
|
||||
return "", "", ""
|
||||
}
|
||||
return out.String()
|
||||
return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
|
||||
}
|
||||
|
||||
func sysInfo() (string, string, string) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 60 * time.Second
|
||||
certValidationTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
@@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
resultChan := make(chan bool, 1)
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
// Release from inside the callbacks so a post-timeout promise resolution
|
||||
// does not invoke an already-released func.
|
||||
var thenFn, catchFn js.Func
|
||||
var releaseOnce sync.Once
|
||||
release := func() {
|
||||
releaseOnce.Do(func() {
|
||||
thenFn.Release()
|
||||
catchFn.Release()
|
||||
})
|
||||
}
|
||||
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
defer release()
|
||||
resultChan <- args[0].Bool()
|
||||
return nil
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
})
|
||||
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
defer release()
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
|
||||
promise.Call("then", thenFn).Call("catch", catchFn)
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -57,6 +58,8 @@ type RDCleanPathProxy struct {
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
pendingHandlers map[string]js.Func
|
||||
nextID atomic.Uint64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -66,8 +69,15 @@ type proxyConnection struct {
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||
// global handle map and MUST be released, otherwise every connection
|
||||
// leaks the Go memory the closure captures.
|
||||
wsHandlerFn js.Func
|
||||
onMessageFn js.Func
|
||||
onCloseFn js.Func
|
||||
cleanupOnce sync.Once
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
@@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
// CreateProxy creates a new proxy endpoint for the given destination.
|
||||
// The registered handler fn and its destinations/pendingHandlers entries are
|
||||
// only released once a connection is established and cleanupConnection runs.
|
||||
// If a caller invokes CreateProxy but never connects to the returned URL,
|
||||
// those entries stay pinned for the lifetime of the page.
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := net.JoinHostPort(hostname, port)
|
||||
|
||||
@@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
@@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
@@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
p.mu.Lock()
|
||||
if p.pendingHandlers == nil {
|
||||
p.pendingHandlers = make(map[string]js.Func)
|
||||
}
|
||||
p.pendingHandlers[proxyID] = handlerFn
|
||||
p.mu.Unlock()
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
@@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
if fn, ok := p.pendingHandlers[proxyID]; ok {
|
||||
conn.wsHandlerFn = fn
|
||||
delete(p.pendingHandlers, proxyID)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
@@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
@@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
ws.Set("onGoMessage", conn.onMessageFn)
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
ws.Set("onGoClose", conn.onCloseFn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
@@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
conn.cleanupOnce.Do(func() {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
|
||||
|
||||
// Detach before releasing so late JS calls surface as TypeError instead
|
||||
// of silent "call to released function".
|
||||
if conn.wsHandlers.Truthy() {
|
||||
conn.wsHandlers.Set("onGoMessage", js.Undefined())
|
||||
conn.wsHandlers.Set("onGoClose", js.Undefined())
|
||||
}
|
||||
|
||||
// wsHandlerFn may be zero-value if the pending handler lookup missed.
|
||||
if conn.wsHandlerFn.Truthy() {
|
||||
conn.wsHandlerFn.Release()
|
||||
}
|
||||
if conn.onMessageFn.Truthy() {
|
||||
conn.onMessageFn.Release()
|
||||
}
|
||||
if conn.onCloseFn.Truthy() {
|
||||
conn.onCloseFn.Release()
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
delete(p.destinations, conn.id)
|
||||
delete(p.pendingHandlers, conn.id)
|
||||
p.mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
func CreateJSInterface(client *Client) js.Value {
|
||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
@@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value {
|
||||
|
||||
_, err := client.Write(bytes)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
})
|
||||
jsInterface.Set("write", writeFunc)
|
||||
|
||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
@@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value {
|
||||
rows := args[1].Int()
|
||||
err := client.Resize(cols, rows)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
})
|
||||
jsInterface.Set("resize", resizeFunc)
|
||||
|
||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
client.Close()
|
||||
return js.Undefined()
|
||||
}))
|
||||
})
|
||||
jsInterface.Set("close", closeFunc)
|
||||
|
||||
go readLoop(client, jsInterface)
|
||||
go func() {
|
||||
readLoop(client, jsInterface)
|
||||
// Detach before releasing so late JS calls surface as TypeError instead
|
||||
// of silent "call to released function".
|
||||
jsInterface.Set("write", js.Undefined())
|
||||
jsInterface.Set("resize", js.Undefined())
|
||||
jsInterface.Set("close", js.Undefined())
|
||||
writeFunc.Release()
|
||||
resizeFunc.Release()
|
||||
closeFunc.Release()
|
||||
}()
|
||||
|
||||
return jsInterface
|
||||
}
|
||||
|
||||
@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
}
|
||||
|
||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
var relayAcceptFn func(conn listener.Conn)
|
||||
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||
}
|
||||
|
||||
// Embedded IdP (Dex)
|
||||
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(w, r)
|
||||
|
||||
// Management HTTP API (default)
|
||||
default:
|
||||
httpHandler.ServeHTTP(w, r)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||
|
||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -499,8 +499,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
|
||||
@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
found = true
|
||||
select {
|
||||
case channel <- update:
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
@@ -47,6 +48,11 @@ type EphemeralManager struct {
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
|
||||
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
|
||||
// no-op when the receiver is nil so deployments without an app
|
||||
// metrics provider work unchanged.
|
||||
metrics *telemetry.EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
@@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetrics attaches a metrics collector. Safe to call once before
|
||||
// LoadInitialPeers; later attachment is fine but earlier loads won't be
|
||||
// reflected in the gauge. Pass nil to detach.
|
||||
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
|
||||
e.peersLock.Lock()
|
||||
e.metrics = m
|
||||
e.peersLock.Unlock()
|
||||
}
|
||||
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
@@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.removePeer(peer.ID)
|
||||
if e.removePeer(peer.ID) {
|
||||
e.metrics.DecPending(1)
|
||||
}
|
||||
|
||||
// stop the unnecessary timer
|
||||
if e.headPeer == nil && e.timer != nil {
|
||||
@@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
e.metrics.IncPending()
|
||||
if e.timer == nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
@@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
e.metrics.AddPending(int64(len(peers)))
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
||||
}
|
||||
@@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
// Drop the gauge by the number of entries we just took off the list,
|
||||
// regardless of whether the subsequent DeletePeers call succeeds. The
|
||||
// list invariant is what the gauge tracks; failed delete batches are
|
||||
// counted separately via CountCleanupError so we can still see them.
|
||||
if len(deletePeers) > 0 {
|
||||
e.metrics.CountCleanupRun()
|
||||
e.metrics.DecPending(int64(len(deletePeers)))
|
||||
}
|
||||
|
||||
peerIDsPerAccount := make(map[string][]string)
|
||||
for id, p := range deletePeers {
|
||||
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||
@@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||
e.metrics.CountCleanupError()
|
||||
continue
|
||||
}
|
||||
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim
|
||||
e.tailPeer = ep
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) removePeer(id string) {
|
||||
// removePeer drops the entry from the linked list. Returns true if a
|
||||
// matching entry was found and removed so callers can keep the pending
|
||||
// metric gauge in sync.
|
||||
func (e *EphemeralManager) removePeer(id string) bool {
|
||||
if e.headPeer == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if e.headPeer.id == id {
|
||||
@@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
for p := e.headPeer; p.next != nil; p = p.next {
|
||||
@@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
e.tailPeer = p
|
||||
}
|
||||
p.next = p.next.next
|
||||
return
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
|
||||
@@ -5,6 +5,7 @@ package peers
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -35,6 +36,14 @@ type Manager interface {
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
|
||||
// Returns nil with an error when no match exists. No permission check;
|
||||
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
|
||||
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
|
||||
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
|
||||
// to. Used by the proxy's auth path to authorise a request by the calling
|
||||
// peer's group memberships.
|
||||
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -99,6 +108,26 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
|
||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP delegates to the store's indexed lookup.
|
||||
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups returns the peer plus its group memberships. Any store
|
||||
// error returns (nil, nil, err) so callers never receive a valid peer
|
||||
// alongside a non-nil error.
|
||||
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return p, groups, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ package peers
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
account "github.com/netbirdio/netbird/management/server/account"
|
||||
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
@@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
// DeletePeers mocks base method.
|
||||
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP mocks base method.
|
||||
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
|
||||
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerID mocks base method.
|
||||
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups mocks base method.
|
||||
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].([]*types.Group)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
|
||||
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ type Domain struct {
|
||||
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
||||
// Not persisted.
|
||||
SupportsCrowdSec *bool `gorm:"-"`
|
||||
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
|
||||
SupportsPrivate *bool `gorm:"-"`
|
||||
}
|
||||
|
||||
// EventMeta returns activity event metadata for a domain
|
||||
|
||||
@@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||
SupportsCustomPorts: d.SupportsCustomPorts,
|
||||
RequireSubdomain: d.RequireSubdomain,
|
||||
SupportsCrowdsec: d.SupportsCrowdSec,
|
||||
SupportsPrivate: d.SupportsPrivate,
|
||||
}
|
||||
if d.TargetCluster != "" {
|
||||
resp.TargetCluster = &d.TargetCluster
|
||||
|
||||
@@ -35,6 +35,7 @@ type proxyManager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
||||
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
|
||||
ret = append(ret, d)
|
||||
}
|
||||
|
||||
@@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
if d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
||||
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
|
||||
}
|
||||
// Custom domains never require a subdomain by default since
|
||||
// the account owns them and should be able to use the bare domain.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
@@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ type Manager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
|
||||
@@ -17,10 +17,11 @@ type store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
@@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
|
||||
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
@@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,16 +15,16 @@ import (
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
@@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
@@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
|
||||
@@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate mocks base method.
|
||||
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
|
||||
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -20,6 +20,9 @@ type Capabilities struct {
|
||||
RequireSubdomain *bool
|
||||
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
||||
SupportsCrowdsec *bool
|
||||
// Private indicates whether this proxy supports inbound access via Wireguard
|
||||
// tunnel and netbird-only authentication policies
|
||||
Private *bool
|
||||
}
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
@@ -42,10 +45,34 @@ func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
|
||||
// ClusterType is the source of a proxy cluster.
|
||||
type ClusterType string
|
||||
|
||||
const (
|
||||
// ClusterTypeAccount is a cluster operated by the account itself (BYOP) —
|
||||
// at least one proxy row in the cluster carries a non-NULL account_id.
|
||||
ClusterTypeAccount ClusterType = "account"
|
||||
// ClusterTypeShared is a cluster operated by NetBird and shared across
|
||||
// accounts — all proxy rows in the cluster have account_id IS NULL.
|
||||
ClusterTypeShared ClusterType = "shared"
|
||||
)
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
//
|
||||
// Online and ConnectedProxies derive from the same 2-min active window
|
||||
// the rest of the module uses, but Cluster rows are not gated on it —
|
||||
// the cluster listing surfaces offline clusters too so operators can
|
||||
// see and clean them up. The 1-hour heartbeat reaper still bounds the
|
||||
// table eventually.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
Type ClusterType
|
||||
Online bool
|
||||
ConnectedProxies int
|
||||
SelfHosted bool
|
||||
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
|
||||
SupportsCustomPorts *bool
|
||||
RequireSubdomain *bool
|
||||
SupportsCrowdSec *bool
|
||||
Private *bool
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||
|
||||
@@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -122,21 +122,6 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetAllServices mocks base method.
|
||||
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -152,19 +137,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
// GetClusters mocks base method.
|
||||
func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
// GetClusters indicates an expected call of GetClusters.
|
||||
func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
@@ -197,6 +182,21 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -196,10 +196,15 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SelfHosted: c.SelfHosted,
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
Type: api.ProxyClusterType(c.Type),
|
||||
Online: c.Online,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||
Private: c.Private,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,8 @@ type ClusterDeriver interface {
|
||||
type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -112,8 +114,12 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeReaper.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
// GetClusters returns every proxy cluster visible to the account
|
||||
// (shared + its own BYOP), regardless of whether any proxy in the
|
||||
// cluster is currently heartbeating. Each cluster is enriched with the
|
||||
// capability flags reported by its active proxies so the dashboard can
|
||||
// render feature support without a second round-trip.
|
||||
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -122,7 +128,19 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||
clusters, err := m.store.GetProxyClusters(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range clusters {
|
||||
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
@@ -192,6 +210,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
target.Host = resource.Domain
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
case service.TargetTypeCluster:
|
||||
// Cluster targets carry the upstream address on target_id; the
|
||||
// proxy resolves the destination at request time.
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
@@ -763,6 +784,10 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case service.TargetTypeCluster:
|
||||
if err := validateClusterTarget(target); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
||||
}
|
||||
@@ -770,6 +795,13 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateClusterTarget(target *service.Target) error {
|
||||
if !target.Options.DirectUpstream {
|
||||
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
@@ -946,12 +978,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1344,3 +1344,66 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
// No peer or resource lookups must be issued for cluster targets.
|
||||
targets := []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Options: rpservice.TargetOptions{DirectUpstream: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
|
||||
}
|
||||
|
||||
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
|
||||
// store-side check that cluster targets must opt into the host-stack dial
|
||||
// path. Without DirectUpstream the proxy would route this target through
|
||||
// the embedded NetBird client and fail on every request.
|
||||
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
targets := []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Host: "backend.lan",
|
||||
},
|
||||
}
|
||||
err := validateTargetReferences(ctx, mockStore, accountID, targets)
|
||||
require.Error(t, err, "cluster target without direct_upstream must be rejected")
|
||||
assert.ErrorContains(t, err, "direct upstream disabled")
|
||||
}
|
||||
|
||||
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
mgr := &Manager{store: mockStore}
|
||||
|
||||
svc := &rpservice.Service{
|
||||
ID: "svc-1",
|
||||
AccountID: accountID,
|
||||
Targets: []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Host: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
|
||||
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
|
||||
}
|
||||
|
||||
@@ -45,10 +45,11 @@ const (
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
TargetTypeCluster TargetType = "cluster"
|
||||
|
||||
SourcePermanent = "permanent"
|
||||
SourceEphemeral = "ephemeral"
|
||||
@@ -60,6 +61,11 @@ type TargetOptions struct {
|
||||
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
|
||||
// the target via the proxy host's network stack. Useful for upstreams
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars). Default false.
|
||||
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -67,7 +73,7 @@ type Target struct {
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Host string `json:"host"`
|
||||
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
@@ -200,6 +206,10 @@ type Service struct {
|
||||
Mode string `gorm:"default:'http'"`
|
||||
ListenPort uint16
|
||||
PortAutoAssigned bool
|
||||
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
|
||||
Private bool
|
||||
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
|
||||
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||
@@ -299,6 +309,12 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
Private: &s.Private,
|
||||
}
|
||||
|
||||
if len(s.AccessGroups) > 0 {
|
||||
groups := append([]string(nil), s.AccessGroups...)
|
||||
resp.AccessGroups = &groups
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
@@ -308,6 +324,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToProtoMapping converts the service into the wire format the proxy consumes.
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := s.buildPathMappings()
|
||||
|
||||
@@ -349,6 +366,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Mode: s.Mode,
|
||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||
Private: s.Private,
|
||||
}
|
||||
|
||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||
@@ -455,7 +473,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||
}
|
||||
|
||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
|
||||
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
apiOpts := &api.ServiceTargetOptions{}
|
||||
@@ -477,17 +496,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if len(opts.CustomHeaders) > 0 {
|
||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||
}
|
||||
if opts.DirectUpstream {
|
||||
apiOpts.DirectUpstream = &opts.DirectUpstream
|
||||
}
|
||||
return apiOpts
|
||||
}
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
@@ -537,6 +561,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
|
||||
if o.CustomHeaders != nil {
|
||||
opts.CustomHeaders = *o.CustomHeaders
|
||||
}
|
||||
if o.DirectUpstream != nil {
|
||||
opts.DirectUpstream = *o.DirectUpstream
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
@@ -551,6 +578,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
if req.ListenPort != nil {
|
||||
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
||||
}
|
||||
if req.Private != nil {
|
||||
s.Private = *req.Private
|
||||
}
|
||||
if req.AccessGroups != nil {
|
||||
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
|
||||
} else {
|
||||
s.AccessGroups = nil
|
||||
}
|
||||
|
||||
targets, err := targetsFromAPI(accountID, req.Targets)
|
||||
if err != nil {
|
||||
@@ -740,6 +775,9 @@ func (s *Service) Validate() error {
|
||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.validatePrivateRequirements(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch s.Mode {
|
||||
case ModeHTTP:
|
||||
@@ -753,6 +791,23 @@ func (s *Service) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
|
||||
func (s *Service) validatePrivateRequirements() error {
|
||||
if !s.Private {
|
||||
return nil
|
||||
}
|
||||
if s.Mode != "" && s.Mode != ModeHTTP {
|
||||
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
|
||||
}
|
||||
if len(s.AccessGroups) == 0 {
|
||||
return errors.New("private services require at least one access group")
|
||||
}
|
||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateHTTPMode() error {
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
@@ -799,11 +854,21 @@ func (s *Service) validateHTTPTargets() error {
|
||||
for i, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// host field will be ignored
|
||||
// Host is normally overwritten by replaceHostByLookup with the
|
||||
// resolved peer IP / resource address; operator-supplied values
|
||||
// are honored only when DirectUpstream is set. Validate the
|
||||
// override here so misconfigured hosts fail fast at API time.
|
||||
if err := validateDirectUpstreamHost(i, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
if err := validateClusterTarget(i, target); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||
}
|
||||
@@ -821,25 +886,67 @@ func (s *Service) validateHTTPTargets() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
|
||||
func validateClusterTarget(idx int, target *Target) error {
|
||||
host := strings.TrimSpace(target.Host)
|
||||
if host == "" {
|
||||
return fmt.Errorf("target %d: has empty host", idx)
|
||||
}
|
||||
if !target.Options.DirectUpstream {
|
||||
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
|
||||
}
|
||||
return validateDirectUpstreamHost(idx, target)
|
||||
}
|
||||
|
||||
// validateDirectUpstreamHost validates the operator-supplied Host on a
|
||||
// peer/host/domain target when DirectUpstream is set. Empty Host is
|
||||
// allowed — the lookup fills in the default peer IP / resource address.
|
||||
// Without DirectUpstream the Host value is silently overwritten by
|
||||
// replaceHostByLookup, so we don't validate it (preserves the historical
|
||||
// behaviour where APIs accepted any value and dropped it). Non-empty
|
||||
// Host with DirectUpstream must look like a hostname or IP and must
|
||||
// not carry a port (port lives on Target.Port).
|
||||
func validateDirectUpstreamHost(idx int, target *Target) error {
|
||||
if !target.Options.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
host := strings.TrimSpace(target.Host)
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.ContainsAny(host, " \t/") {
|
||||
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(host); err == nil {
|
||||
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateL4Target(target *Target) error {
|
||||
// L4 services have a single target; per-target disable is meaningless
|
||||
// (use the service-level Enabled flag instead). Force it on so that
|
||||
// buildPathMappings always includes the target in the proto.
|
||||
target.Enabled = true
|
||||
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
if target.TargetId == "" {
|
||||
return errors.New("target_id is required for L4 services")
|
||||
}
|
||||
if target.TargetType != TargetTypeCluster && target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// OK
|
||||
if err := validateDirectUpstreamHost(0, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return errors.New("target host is required for subnet targets")
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
// target_id carries the cluster address; the proxy resolves
|
||||
// the upstream at request time.
|
||||
default:
|
||||
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
||||
}
|
||||
@@ -1174,6 +1281,11 @@ func (s *Service) Copy() *Service {
|
||||
}
|
||||
}
|
||||
|
||||
var accessGroups []string
|
||||
if len(s.AccessGroups) > 0 {
|
||||
accessGroups = append([]string(nil), s.AccessGroups...)
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
AccountID: s.AccountID,
|
||||
@@ -1195,6 +1307,8 @@ func (s *Service) Copy() *Service {
|
||||
Mode: s.Mode,
|
||||
ListenPort: s.ListenPort,
|
||||
PortAutoAssigned: s.PortAutoAssigned,
|
||||
Private: s.Private,
|
||||
AccessGroups: accessGroups,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -1116,3 +1117,191 @@ func TestValidate_HeaderAuths(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
|
||||
}
|
||||
|
||||
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
|
||||
// rule that operator-supplied Host is mandatory: cluster targets dial the
|
||||
// upstream via the host network stack (direct_upstream is implied), so an
|
||||
// empty Host leaves the proxy with nothing to dial.
|
||||
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
|
||||
}
|
||||
|
||||
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
|
||||
// half of the cluster-target rule: DirectUpstream must be true so the
|
||||
// stdlib transport branch in MultiTransport is taken. Without it the
|
||||
// embedded NetBird client would try to dial the cluster address through
|
||||
// the WG tunnel, which is the wrong network for a cluster upstream.
|
||||
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
|
||||
}
|
||||
|
||||
func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Mode = ModeTCP
|
||||
rp.ListenPort = 9000
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
|
||||
}
|
||||
|
||||
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
|
||||
svc := validProxy()
|
||||
svc.Private = true
|
||||
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
|
||||
cp := svc.Copy()
|
||||
require.NotNil(t, cp)
|
||||
assert.True(t, cp.Private)
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
|
||||
|
||||
cp.Private = false
|
||||
assert.True(t, svc.Private)
|
||||
|
||||
cp.AccessGroups[0] = "grp-other"
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
|
||||
}
|
||||
|
||||
func TestService_APIRoundtrip_Private(t *testing.T) {
|
||||
enabled := true
|
||||
private := true
|
||||
accessGroups := []string{"grp-admins"}
|
||||
targets := []api.ServiceTarget{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: api.ServiceTargetTargetType("cluster"),
|
||||
Protocol: "http",
|
||||
Port: 80,
|
||||
Enabled: true,
|
||||
}}
|
||||
req := &api.ServiceRequest{
|
||||
Name: "svc-private",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
Enabled: enabled,
|
||||
Private: &private,
|
||||
AccessGroups: &accessGroups,
|
||||
Targets: &targets,
|
||||
}
|
||||
|
||||
svc := &Service{}
|
||||
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
|
||||
assert.True(t, svc.Private)
|
||||
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
|
||||
|
||||
resp := svc.ToAPIResponse()
|
||||
require.NotNil(t, resp.Private)
|
||||
assert.True(t, *resp.Private)
|
||||
require.NotNil(t, resp.AccessGroups)
|
||||
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
|
||||
}
|
||||
|
||||
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "access group")
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Auth.BearerAuth = &BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-sso"},
|
||||
}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Mode = ModeTCP
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "HTTP")
|
||||
}
|
||||
|
||||
@@ -20,6 +20,20 @@ type KeyPair struct {
|
||||
type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
Method auth.Method `json:"method"`
|
||||
// Email is the calling user's email address. Carried so the
|
||||
// proxy can stamp identity on upstream requests (e.g.
|
||||
// x-litellm-end-user-id) without an extra management
|
||||
// round-trip on every cookie-bearing request.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Groups carries the user's group IDs so the proxy can stamp them
|
||||
// onto upstream requests (X-NetBird-Groups) from the cookie path
|
||||
// without an extra management round-trip.
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
// GroupNames carries the human-readable display names for the ids
|
||||
// in Groups, ordered identically (positional pairing). Slice may be
|
||||
// shorter than Groups for tokens minted before names were
|
||||
// resolvable; the consumer falls back to ids for missing positions.
|
||||
GroupNames []string `json:"group_names,omitempty"`
|
||||
}
|
||||
|
||||
func GenerateKeyPair() (*KeyPair, error) {
|
||||
@@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||
// SignToken mints a session JWT for the given user and domain. email,
|
||||
// groups, and groupNames, when non-empty, are embedded so the proxy can
|
||||
// authorise and stamp identity for policy-aware middlewares without a
|
||||
// management round-trip on every cookie-bearing request. groupNames
|
||||
// pairs positionally with groups; pass nil when names couldn't be
|
||||
// resolved.
|
||||
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
|
||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode private key: %w", err)
|
||||
@@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
Method: method,
|
||||
Method: method,
|
||||
Email: email,
|
||||
Groups: append([]string(nil), groups...),
|
||||
GroupNames: append([]string(nil), groupNames...),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||
|
||||
@@ -10,8 +10,10 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/rs/cors"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -19,7 +21,6 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
@@ -27,16 +28,20 @@ import (
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
|
||||
|
||||
func (s *BaseServer) EventStore() activity.Store {
|
||||
return Create(s, func() activity.Store {
|
||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||
var err error
|
||||
key := s.Config.DataStoreEncryptionKey
|
||||
if key == "" {
|
||||
log.Debugf("generate new activity store encryption key")
|
||||
key, err = crypt.GenerateKey()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to generate event store encryption key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize event store: %v", err)
|
||||
}
|
||||
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
|
||||
// the deployment isn't using the embedded variant.
|
||||
func (s *BaseServer) IDPHandler() http.Handler {
|
||||
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
|
||||
if !ok || embeddedIdP == nil {
|
||||
return nil
|
||||
}
|
||||
return cors.AllowAll().Handler(embeddedIdP.Handler())
|
||||
}
|
||||
|
||||
func (s *BaseServer) Router() *mux.Router {
|
||||
return Create(s, func() *mux.Router {
|
||||
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager {
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
||||
integratedPeerValidator, err := validator.NewIntegratedValidator(
|
||||
context.Background(),
|
||||
s.PeersManager(),
|
||||
s.SettingsManager(),
|
||||
@@ -112,7 +113,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
|
||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return Create(s, func() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
if metrics := s.Metrics(); metrics != nil {
|
||||
em.SetMetrics(metrics.EphemeralPeersMetrics())
|
||||
}
|
||||
return em
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||
|
||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||
return Create(s, func() permissions.Manager {
|
||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetAccountManager(s.AccountManager())
|
||||
})
|
||||
|
||||
return manager
|
||||
return permissions.NewManager(s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return idpManager
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
|
||||
switch {
|
||||
case s.certManager != nil:
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
||||
log.Tracef("custom handler set successfully")
|
||||
}
|
||||
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
// Check if a custom handler was set (for multiplexing additional services)
|
||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||
if handler, ok := customHandler.(http.Handler); ok {
|
||||
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
|
||||
gRPCHandler.ServeHTTP(writer, request)
|
||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||
wsProxy.Handler().ServeHTTP(writer, request)
|
||||
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(writer, request)
|
||||
default:
|
||||
httpHandler.ServeHTTP(writer, request)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -136,9 +137,12 @@ type proxyConnection struct {
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// syncStream is set when the proxy connected via SyncMappings.
|
||||
// When non-nil, the sender goroutine uses this instead of stream.
|
||||
syncStream proto.ProxyService_SyncMappingsServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
@@ -206,145 +210,323 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
// proxyConnectParams holds the validated parameters extracted from either
|
||||
// a GetMappingUpdateRequest or a SyncMappingsInit message.
|
||||
type proxyConnectParams struct {
|
||||
proxyID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||
ctx := stream.Context()
|
||||
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
log.Infof("New proxy connection from %s", peerInfo)
|
||||
|
||||
proxyID := req.GetProxyId()
|
||||
if proxyID == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
|
||||
proxyAddress := req.GetAddress()
|
||||
if !isProxyAddressValid(proxyAddress) {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
return err
|
||||
}
|
||||
params.capabilities = req.GetCapabilities()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
stream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
|
||||
}
|
||||
|
||||
// SyncMappings implements the bidirectional SyncMappings RPC.
|
||||
// It mirrors GetMappingUpdate but provides application-level back-pressure:
|
||||
// management waits for an ack from the proxy before sending the next batch.
|
||||
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
init, err := recvSyncInit(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = init.GetCapabilities()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
syncStream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
go s.drainRecv(stream, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
|
||||
}
|
||||
|
||||
// recvSyncInit receives and validates the first message on a SyncMappings stream.
|
||||
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
|
||||
firstMsg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
|
||||
}
|
||||
init := firstMsg.GetInit()
|
||||
if init == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
|
||||
}
|
||||
return init, nil
|
||||
}
|
||||
|
||||
// validateProxyConnect validates the proxy ID and address, and checks cluster
|
||||
// address availability for account-scoped tokens.
|
||||
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
|
||||
if proxyID == "" {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
if !isProxyAddressValid(address) {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
|
||||
if err != nil {
|
||||
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
|
||||
}
|
||||
}
|
||||
|
||||
return proxyConnectParams{proxyID: proxyID, address: address}, nil
|
||||
}
|
||||
|
||||
// registerProxyConnection creates a proxyConnection, registers it with the
|
||||
// proxy manager and cluster, and stores it in connectedProxies. The caller
|
||||
// provides a partially initialised connSeed with stream-specific fields set;
|
||||
// the remaining fields are filled in here.
|
||||
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
|
||||
var accountID *string
|
||||
var tokenID string
|
||||
if token := GetProxyTokenFromContext(ctx); token != nil {
|
||||
if token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
}
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
s.supersedePriorConnection(params.proxyID, sessionID)
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
connSeed.proxyID = params.proxyID
|
||||
connSeed.sessionID = sessionID
|
||||
connSeed.address = params.address
|
||||
connSeed.accountID = accountID
|
||||
connSeed.tokenID = tokenID
|
||||
connSeed.capabilities = params.capabilities
|
||||
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
|
||||
connSeed.ctx = connCtx
|
||||
connSeed.cancel = cancel
|
||||
|
||||
var caps *proxy.Capabilities
|
||||
if c := params.capabilities; c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
Private: c.Private,
|
||||
}
|
||||
}
|
||||
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
|
||||
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(params.proxyID, connSeed)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
|
||||
}
|
||||
|
||||
return connSeed, proxyRecord, nil
|
||||
}
|
||||
|
||||
// supersedePriorConnection cancels any existing connection for the given proxy.
|
||||
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": newSessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupFailedSnapshot removes the connection from the cluster and store
|
||||
// after a snapshot send failure.
|
||||
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
|
||||
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
conn.cancel()
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// drainRecv consumes and discards messages from a bidirectional stream.
|
||||
// The proxy sends an ack for every incremental update; we don't need them
|
||||
// after the snapshot phase. Recv errors are forwarded to errChan.
|
||||
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
|
||||
// and wait for termination. When bidi is true, normal stream closure (EOF,
|
||||
// canceled) is treated as a clean disconnect rather than an error.
|
||||
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": conn.proxyID,
|
||||
"session_id": conn.sessionID,
|
||||
"address": conn.address,
|
||||
"cluster_addr": conn.address,
|
||||
"account_id": conn.accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
defer s.disconnectProxy(conn)
|
||||
go s.heartbeat(conn.ctx, conn, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
if bidi && isStreamClosed(err) {
|
||||
log.Infof("Proxy %s stream closed", conn.proxyID)
|
||||
return nil
|
||||
}
|
||||
log.Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
|
||||
case <-conn.ctx.Done():
|
||||
log.Infof("Proxy %s context canceled", conn.proxyID)
|
||||
return conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// disconnectProxy removes the connection from cluster and store, unless it
|
||||
// has already been superseded by a newer connection.
|
||||
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
|
||||
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
|
||||
}
|
||||
|
||||
conn.cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
|
||||
}
|
||||
|
||||
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
|
||||
// one batch, then waits for the proxy to ack before sending the next.
|
||||
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive ack: %w", err)
|
||||
}
|
||||
if msg.GetAck() == nil {
|
||||
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
@@ -381,6 +563,9 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
@@ -394,6 +579,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
@@ -425,18 +617,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := s.GetOIDCValidationConfig()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
@@ -457,12 +645,26 @@ func isProxyAddressValid(addr string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy
|
||||
// isStreamClosed returns true for errors that indicate normal stream
|
||||
// termination: io.EOF, context cancellation, or gRPC Canceled.
|
||||
func isStreamClosed(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
return status.Code(err) == codes.Canceled
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy.
|
||||
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
|
||||
// otherwise the legacy GetMappingUpdateResponse stream is used.
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.stream.Send(resp); err != nil {
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -472,6 +674,17 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a mapping update on whichever stream the proxy connected with.
|
||||
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
|
||||
if conn.syncStream != nil {
|
||||
return conn.syncStream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: resp.Mapping,
|
||||
InitialSyncComplete: resp.InitialSyncComplete,
|
||||
})
|
||||
}
|
||||
return conn.stream.Send(resp)
|
||||
}
|
||||
|
||||
// SendAccessLog processes access log from proxy
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
@@ -538,10 +751,15 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
|
||||
connUpdate = filterMappingsForProxy(conn, connUpdate)
|
||||
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
|
||||
return true
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
@@ -670,16 +888,20 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
}
|
||||
}
|
||||
|
||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
||||
// mappings with a custom listen port, since they don't understand the
|
||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
||||
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
|
||||
// Private mappings require SupportsPrivateService; custom-port L4 mappings
|
||||
// require SupportsCustomPorts. Remove operations always pass so proxies can
|
||||
// clean up.
|
||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
return true
|
||||
}
|
||||
if mapping.GetPrivate() {
|
||||
caps := conn.capabilities
|
||||
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||
return true
|
||||
}
|
||||
@@ -688,6 +910,29 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||
}
|
||||
|
||||
// filterMappingsForProxy drops mappings the proxy cannot safely receive
|
||||
// (e.g. private mappings to a proxy without SupportsPrivateService).
|
||||
// Returns the input unchanged when no filtering is needed.
|
||||
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||
if update == nil || len(update.Mapping) == 0 {
|
||||
return update
|
||||
}
|
||||
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, m := range update.Mapping {
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, m)
|
||||
}
|
||||
if len(kept) == len(update.Mapping) {
|
||||
return update
|
||||
}
|
||||
return &proto.GetMappingUpdateResponse{
|
||||
Mapping: kept,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
@@ -749,7 +994,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
|
||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
||||
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
|
||||
// secrets and have no user-level group context, so groups stay nil. Email
|
||||
// is also empty — these schemes don't resolve a user record at sign time.
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -838,7 +1086,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -846,8 +1094,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
token, err := sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
userEmail,
|
||||
service.Domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -858,6 +1109,26 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
|
||||
// into parallel id and name slices. ids[i] and names[i] always pair to
|
||||
// the same group. nil entries (orphan ids the manager couldn't resolve)
|
||||
// are skipped so the consumer can rely on positional pairing.
|
||||
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
|
||||
if len(groups) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
ids = make([]string, 0, len(groups))
|
||||
names = make([]string, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, g.ID)
|
||||
names = append(names, g.Name)
|
||||
}
|
||||
return ids, names
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
@@ -1122,7 +1393,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
return verifier, redirectURL, nil
|
||||
}
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and
|
||||
// user. The user's group memberships are embedded in the token so policy-aware
|
||||
// middlewares on the proxy can authorise without an extra management round-trip.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
@@ -1133,11 +1406,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
|
||||
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
||||
}
|
||||
|
||||
var (
|
||||
email string
|
||||
groupIDs []string
|
||||
groupNames []string
|
||||
)
|
||||
if s.usersManager != nil {
|
||||
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
if uerr != nil {
|
||||
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
|
||||
} else if user != nil {
|
||||
email = user.Email
|
||||
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
|
||||
}
|
||||
}
|
||||
|
||||
return sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userID,
|
||||
email,
|
||||
domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
}
|
||||
@@ -1241,7 +1532,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1254,7 +1545,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
user, err := s.usersManager.GetUser(ctx, userID)
|
||||
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1288,12 +1579,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"user_id": userID,
|
||||
"error": err.Error(),
|
||||
}).Debug("ValidateSession: access denied")
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1303,10 +1597,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"email": user.Email,
|
||||
}).Debug("ValidateSession: access granted")
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1339,3 +1636,154 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
|
||||
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||
// checks the peer's group membership against the service's access groups.
|
||||
// Peers without a user (machine agents, automation workloads) are first-class
|
||||
// callers; authorisation runs off peer-group memberships rather than the
|
||||
// optional owning user's auto-groups. On success a session JWT is minted so
|
||||
// the proxy can install a cookie and skip subsequent management round-trips.
|
||||
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
domain := req.GetDomain()
|
||||
tunnelIPStr := req.GetTunnelIp()
|
||||
|
||||
if domain == "" || tunnelIPStr == "" {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "missing domain or tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
tunnelIP := net.ParseIP(tunnelIPStr)
|
||||
if tunnelIP == nil {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "invalid_tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "service_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
|
||||
// validate and mint session cookies for their own account's domains.
|
||||
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
|
||||
if err != nil || peer == nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the human
|
||||
// is the principal so multiple peers owned by the same user share a
|
||||
// single identity. Unlinked peers (machine agents) are their own
|
||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// tag spend with — user.Email when linked, peer.Name when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"tunnel_ip": tunnelIPStr,
|
||||
"peer_id": peer.ID,
|
||||
"principal_id": principalID,
|
||||
}).Debug("ValidateTunnelPeer: access granted")
|
||||
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
SessionToken: token,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||
// groups. Private services authorise against AccessGroups (empty list fails
|
||||
// closed — Validate() rejects that at save time but the RPC is the security
|
||||
// boundary and must not trust upstream state). Bearer-auth services authorise
|
||||
// against DistributionGroups when populated. Non-private non-bearer services
|
||||
// are open.
|
||||
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
|
||||
if service.Private {
|
||||
if len(service.AccessGroups) == 0 {
|
||||
return fmt.Errorf("private service has no access groups")
|
||||
}
|
||||
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
|
||||
}
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
|
||||
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
|
||||
// else a non-nil error.
|
||||
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
|
||||
if len(allowedGroups) == 0 {
|
||||
return fmt.Errorf("no allowed groups configured")
|
||||
}
|
||||
allowed := make(map[string]struct{}, len(allowedGroups))
|
||||
for _, g := range allowedGroups {
|
||||
allowed[g] = struct{}{}
|
||||
}
|
||||
for _, g := range peerGroupIDs {
|
||||
if _, ok := allowed[g]; ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("peer not in allowed groups")
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPeerGroupAccess(t *testing.T) {
|
||||
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: nil}
|
||||
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no access groups")
|
||||
})
|
||||
|
||||
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
|
||||
})
|
||||
|
||||
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-allowed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("non-private non-bearer is open", func(t *testing.T) {
|
||||
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
type hookingStream struct {
|
||||
grpc.ServerStream
|
||||
onSend func(*proto.GetMappingUpdateResponse)
|
||||
}
|
||||
|
||||
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
if s.onSend != nil {
|
||||
s.onSend(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *hookingStream) Context() context.Context { return context.Background() }
|
||||
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *hookingStream) SendMsg(any) error { return nil }
|
||||
func (s *hookingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6
|
||||
const ttl = 100 * time.Millisecond
|
||||
const sendDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
var validateErrs []error
|
||||
stream := &hookingStream{
|
||||
onSend: func(resp *proto.GetMappingUpdateResponse) {
|
||||
for _, m := range resp.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
|
||||
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
|
||||
}
|
||||
}
|
||||
time.Sleep(sendDelay)
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
// Send immediately (first update or after quiet period)
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
uncanceledCTX := context.WithoutCancel(ctx)
|
||||
unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
||||
s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime)
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
|
||||
411
management/internals/shared/grpc/sync_mappings_test.go
Normal file
411
management/internals/shared/grpc/sync_mappings_test.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records
|
||||
// sent messages and returns pre-loaded ack responses from Recv.
|
||||
type syncRecordingStream struct {
|
||||
grpc.ServerStream
|
||||
|
||||
mu sync.Mutex
|
||||
sent []*proto.SyncMappingsResponse
|
||||
recvMsgs []*proto.SyncMappingsRequest
|
||||
recvIdx int
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sent = append(s.sent, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.recvIdx >= len(s.recvMsgs) {
|
||||
return nil, fmt.Errorf("no more recv messages")
|
||||
}
|
||||
msg := s.recvMsgs[s.recvIdx]
|
||||
s.recvIdx++
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *syncRecordingStream) SendMsg(any) error { return nil }
|
||||
func (s *syncRecordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func ackMsg() *proto.SyncMappingsRequest {
|
||||
return &proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1 → 3 batches, 3 acks (one per batch, including final)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches")
|
||||
|
||||
assert.Len(t, stream.sent[0].Mapping, 3)
|
||||
assert.False(t, stream.sent[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[1].Mapping, 3)
|
||||
assert.False(t, stream.sent[1].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[2].Mapping, 1)
|
||||
assert.True(t, stream.sent[2].InitialSyncComplete)
|
||||
|
||||
// All 3 acks consumed — including the final batch.
|
||||
assert.Equal(t, 3, stream.recvIdx)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, totalServices)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.sent[0].Mapping)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4 // 2 batches → 1 ack needed, but we provide none
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// No acks available — Recv will return error.
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "receive ack")
|
||||
// First batch should have been sent before the error.
|
||||
require.Len(t, stream.sent, 1)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// Send an init message instead of an ack.
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{
|
||||
{Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}},
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expected ack")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) {
|
||||
// Verify batches are sent strictly sequentially — batch N+1 is not sent
|
||||
// until the ack for batch N is received, including the final batch.
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6 // 3 batches, 3 acks
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
var mu sync.Mutex
|
||||
var events []string
|
||||
|
||||
// Build a stream that logs send/recv events so we can verify ordering.
|
||||
ackCh := make(chan struct{}, 3)
|
||||
stream := &orderTrackingStream{
|
||||
mu: &mu,
|
||||
events: &events,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
// Feed acks asynchronously after a short delay to simulate real proxy.
|
||||
go func() {
|
||||
for range 3 {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
ackCh <- struct{}{}
|
||||
}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Expected: send, recv-ack, send, recv-ack, send, recv-ack.
|
||||
require.Len(t, events, 6)
|
||||
assert.Equal(t, "send", events[0])
|
||||
assert.Equal(t, "recv", events[1])
|
||||
assert.Equal(t, "send", events[2])
|
||||
assert.Equal(t, "recv", events[3])
|
||||
assert.Equal(t, "send", events[4])
|
||||
assert.Equal(t, "recv", events[5])
|
||||
}
|
||||
|
||||
// orderTrackingStream logs "send" and "recv" events and blocks Recv until
|
||||
// an ack is signaled via ackCh.
|
||||
type orderTrackingStream struct {
|
||||
grpc.ServerStream
|
||||
mu *sync.Mutex
|
||||
events *[]string
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "send")
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "recv")
|
||||
s.mu.Unlock()
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Context() context.Context { return context.Background() }
|
||||
func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *orderTrackingStream) SendMsg(any) error { return nil }
|
||||
func (s *orderTrackingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
const ttl = 100 * time.Millisecond
|
||||
const ackDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
// Build a stream that validates tokens immediately on Send, then
|
||||
// delays the ack to ensure the next batch's tokens are generated fresh.
|
||||
var validateErrs []error
|
||||
ackCh := make(chan struct{}, 2)
|
||||
stream := &tokenValidatingSyncStream{
|
||||
tokenStore: s.tokenStore,
|
||||
validateErrs: &validateErrs,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Delay first ack so that if tokens were all generated upfront they'd expire.
|
||||
time.Sleep(ackDelay)
|
||||
ackCh <- struct{}{}
|
||||
// Final batch ack — immediate.
|
||||
ackCh <- struct{}{}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid: per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
type tokenValidatingSyncStream struct {
|
||||
grpc.ServerStream
|
||||
tokenStore *OneTimeTokenStore
|
||||
validateErrs *[]error
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
for _, mapping := range m.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil {
|
||||
*s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() }
|
||||
func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {}
|
||||
func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) {
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"},
|
||||
},
|
||||
InitialSyncComplete: true,
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, 1)
|
||||
assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) {
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-2", AccountId: "acct-2"},
|
||||
},
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.messages, 1)
|
||||
assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id)
|
||||
}
|
||||
@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||
|
||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
@@ -125,6 +125,7 @@ func TestValidateSession_UserAllowed(t *testing.T) {
|
||||
assert.True(t, resp.Valid, "User should be allowed access")
|
||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||
assert.Empty(t, resp.DeniedReason)
|
||||
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
@@ -145,6 +146,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||
@@ -322,21 +324,29 @@ func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Conte
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *testValidateSessionServiceManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -291,10 +291,15 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Canonicalize the incoming range so a caller-supplied prefix with host bits
|
||||
// (e.g. 100.64.1.1/16) compares equal to the masked form stored on network.Net.
|
||||
newSettings.NetworkRange = newSettings.NetworkRange.Masked()
|
||||
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
var reloadReverseProxy bool
|
||||
var effectiveOldNetworkRange netip.Prefix
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
@@ -308,6 +313,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
// No lock: the transaction already holds Settings(Update), and network.Net is
|
||||
// only mutated by reallocateAccountPeerIPs, which is reachable only through
|
||||
// this same code path. A Share lock here would extend an unnecessary row lock
|
||||
// and complicate ordering against updatePeerIPv6InTransaction.
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account network: %w", err)
|
||||
}
|
||||
effectiveOldNetworkRange = prefixFromIPNet(network.Net)
|
||||
|
||||
if oldSettings.Extra != nil && newSettings.Extra != nil &&
|
||||
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
|
||||
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
|
||||
@@ -321,7 +336,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
|
||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -396,9 +411,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
|
||||
}
|
||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
|
||||
eventMeta := map[string]any{
|
||||
"old_network_range": oldSettings.NetworkRange.String(),
|
||||
"old_network_range": effectiveOldNetworkRange.String(),
|
||||
"new_network_range": newSettings.NetworkRange.String(),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
@@ -443,6 +458,22 @@ func ipv6SettingsChanged(old, updated *types.Settings) bool {
|
||||
return !slices.Equal(oldGroups, newGroups)
|
||||
}
|
||||
|
||||
// prefixFromIPNet returns the overlay prefix actually allocated on the account
|
||||
// network, or an invalid prefix if none is set. Settings.NetworkRange is a
|
||||
// user-facing override that is empty on legacy accounts, so the effective
|
||||
// range must be read from network.Net to compare against an incoming update.
|
||||
func prefixFromIPNet(ipNet net.IPNet) netip.Prefix {
|
||||
if ipNet.IP == nil {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(ipNet.IP)
|
||||
if !ok {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
ones, _ := ipNet.Mask.Size()
|
||||
return netip.PrefixFrom(addr.Unmap(), ones)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
@@ -1837,35 +1868,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's
|
||||
// network map and then marks the peer connected with a session token
|
||||
// derived from syncTime (the moment the gRPC stream opened). Any
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
||||
if err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||
}
|
||||
|
||||
// OnPeerDisconnected is invoked when a sync stream ends. It marks the
|
||||
// peer disconnected only when the stored SessionStartedAt matches the
|
||||
// nanosecond token derived from streamStartTime — i.e. only when this
|
||||
// is the stream that currently owns the peer's session. A mismatch
|
||||
// means a newer stream has already replaced us, so the disconnect is
|
||||
// dropped.
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if peer.Status.LastSeen.After(streamStartTime) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
||||
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
||||
if err != nil {
|
||||
if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -61,7 +61,8 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
|
||||
@@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// OnPeerDisconnected mocks base method.
|
||||
|
||||
@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err, "unable to get peer")
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal the token we passed in")
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
require.NoError(t, err)
|
||||
@@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
||||
require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0")
|
||||
})
|
||||
|
||||
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
||||
// Older stream tries to mark disconnect with its own (older) session token —
|
||||
// fencing kicks in and the write is dropped.
|
||||
staleStreamStartTime := streamStartTime.Add(-1 * time.Hour)
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected,
|
||||
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
||||
"peer should remain connected because the stored session is newer than the disconnect token")
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should still hold the winning stream's token")
|
||||
})
|
||||
|
||||
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should still be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
||||
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should NOT be overwritten by stale token from blocked goroutine")
|
||||
})
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the
|
||||
// fencing protocol under contention: many goroutines race to mark the
|
||||
// same peer connected with distinct session tokens at the same time.
|
||||
// The contract is that the highest token always wins and is what remains
|
||||
// in the store, regardless of execution order.
|
||||
func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
const workers = 16
|
||||
base := time.Now().UTC().UnixNano()
|
||||
tokens := make([]int64, workers)
|
||||
for i := range tokens {
|
||||
// Spread tokens by 1ms so the comparison is unambiguous; the
|
||||
// largest is index workers-1.
|
||||
tokens[i] = base + int64(i)*int64(time.Millisecond)
|
||||
}
|
||||
expected := tokens[workers-1]
|
||||
|
||||
var ready sync.WaitGroup
|
||||
ready.Add(workers)
|
||||
var start sync.WaitGroup
|
||||
start.Add(1)
|
||||
var done sync.WaitGroup
|
||||
done.Add(workers)
|
||||
|
||||
// require.* calls t.FailNow which is documented as unsafe from
|
||||
// non-test goroutines (it calls runtime.Goexit on the wrong stack and
|
||||
// races with the WaitGroup). Collect errors here and assert from the
|
||||
// main goroutine after done.Wait().
|
||||
errs := make(chan error, workers)
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
token := tokens[i]
|
||||
go func() {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
|
||||
}()
|
||||
}
|
||||
|
||||
ready.Wait()
|
||||
start.Done()
|
||||
done.Wait()
|
||||
close(errs)
|
||||
for err := range errs {
|
||||
require.NoError(t, err, "MarkPeerConnected must not error under contention")
|
||||
}
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected after the race")
|
||||
require.Equal(t, expected, peer.Status.SessionStartedAt,
|
||||
"the largest token must win regardless of execution order")
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3970,6 +4049,96 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved guards against
|
||||
// peer IP reallocation when a settings update carries the network range that is already
|
||||
// in use. Legacy accounts have Settings.NetworkRange unset in the DB while network.Net
|
||||
// holds the actual allocated overlay; the dashboard backfills the GET response from
|
||||
// network.Net and echoes the value back on PUT, so the diff must be against the
|
||||
// effective range to avoid renumbering every peer on an unrelated settings change.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved(t *testing.T) {
|
||||
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.False(t, settings.NetworkRange.IsValid(), "precondition: new accounts leave Settings.NetworkRange unset")
|
||||
|
||||
network, err := manager.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, network.Net.IP, "precondition: network.Net should be allocated")
|
||||
addr, ok := netip.AddrFromSlice(network.Net.IP)
|
||||
require.True(t, ok)
|
||||
ones, _ := network.Net.Mask.Size()
|
||||
effective := netip.PrefixFrom(addr.Unmap(), ones)
|
||||
require.True(t, effective.IsValid())
|
||||
|
||||
before := map[string]netip.Addr{peer1.ID: peer1.IP, peer2.ID: peer2.IP, peer3.ID: peer3.IP}
|
||||
|
||||
// Round-trip the effective range as if the dashboard echoed back the GET-backfilled value.
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: effective,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, len(before))
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when range matches effective", p.ID)
|
||||
}
|
||||
|
||||
// Carrying the same range with host bits set must also be a no-op once canonicalized.
|
||||
hostBitsForm := netip.PrefixFrom(peer1.IP, ones)
|
||||
require.NotEqual(t, effective, hostBitsForm, "precondition: host-bit form should differ before masking")
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: hostBitsForm,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change for host-bit-set equivalent range", p.ID)
|
||||
}
|
||||
|
||||
// Omitting NetworkRange (invalid prefix) must also be a no-op.
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when NetworkRange omitted", p.ID)
|
||||
}
|
||||
|
||||
// Sanity: an actually different range still triggers reallocation.
|
||||
newRange := netip.MustParsePrefix("100.99.0.0/16")
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: newRange,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.True(t, newRange.Contains(p.IP), "peer %s should be in new range %s, got %s", p.ID, newRange, p.IP)
|
||||
assert.NotEqual(t, before[p.ID], p.IP, "peer %s IP should change on real range update", p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) {
|
||||
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -15,15 +15,13 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
@@ -32,12 +30,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
|
||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -56,17 +52,14 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
isValidChildAccount,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
prefix := apiPrefix
|
||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
||||
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
// Check if embedded IdP is enabled for instance manager
|
||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
@@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
oauthHandler.RegisterEndpoints(router)
|
||||
}
|
||||
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
|
||||
}
|
||||
|
||||
return rootRouter, nil
|
||||
return router, nil
|
||||
}
|
||||
|
||||
@@ -444,7 +444,7 @@ func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain stri
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
func (m *testServiceManager) GetClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager serverauth.Manager
|
||||
@@ -35,6 +35,7 @@ type AuthMiddleware struct {
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
patUsageTracker *PATUsageTracker
|
||||
isValidChildAccount IsValidChildAccountFunc
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -45,6 +46,7 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
isValidChildAccount IsValidChildAccountFunc,
|
||||
) *AuthMiddleware {
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
@@ -62,6 +64,7 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
patUsageTracker: patUsageTracker,
|
||||
isValidChildAccount: isValidChildAccount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
@@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
|
||||
@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
@@ -1319,7 +1319,7 @@ func Test_NetworkRouters_Update(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing router creates it",
|
||||
name: "Update non-existing router returns not found",
|
||||
networkId: "testNetworkId",
|
||||
routerId: "nonExistingRouterId",
|
||||
requestBody: &api.NetworkRouterRequest{
|
||||
@@ -1328,11 +1328,7 @@ func Test_NetworkRouters_Update(t *testing.T) {
|
||||
Metric: 100,
|
||||
Enabled: true,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "nonExistingRouterId", router.Id)
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Update router with both peer and peer_groups",
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
@@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
@@ -138,10 +140,13 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
|
||||
}
|
||||
|
||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||
// Build CLI redirect URIs including the device callback. Dex uses the issuer-relative
|
||||
// path (for example, /oauth2/device/callback) when completing the device flow, so
|
||||
// include it explicitly in addition to the legacy bare path and absolute URL.
|
||||
cliRedirectURIs := c.CLIRedirectURIs
|
||||
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer))
|
||||
cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback")
|
||||
|
||||
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
|
||||
dashboardRedirectURIs := c.DashboardRedirectURIs
|
||||
@@ -154,6 +159,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
// MGMT api and the dashboard, adding baseURL means less configuration for the instance admin
|
||||
dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL)
|
||||
|
||||
redirectURIs := make([]string, 0)
|
||||
redirectURIs = append(redirectURIs, cliRedirectURIs...)
|
||||
redirectURIs = append(redirectURIs, dashboardRedirectURIs...)
|
||||
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
@@ -179,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
ID: staticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: dashboardRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
|
||||
},
|
||||
{
|
||||
ID: staticClientCLI,
|
||||
Name: "NetBird CLI",
|
||||
Public: true,
|
||||
RedirectURIs: cliRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
},
|
||||
},
|
||||
StaticConnectors: c.StaticConnectors,
|
||||
@@ -217,6 +226,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func issuerRelativeDeviceCallback(issuer string) string {
|
||||
u, err := url.Parse(issuer)
|
||||
if err != nil || u.Path == "" {
|
||||
return "/device/callback"
|
||||
}
|
||||
return path.Join(u.Path, "/device/callback")
|
||||
}
|
||||
|
||||
// Due to how the frontend generates the logout, sometimes it appends a trailing slash
|
||||
// and because Dex only allows exact matches, we need to make sure we always have both
|
||||
// versions of each provided uri
|
||||
@@ -299,7 +316,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
}
|
||||
|
||||
func validSessionCookieEncryptionKeyLength(length int) bool {
|
||||
|
||||
@@ -314,6 +314,34 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) {
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "https://example.com/oauth2",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(t.TempDir(), "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
yamlConfig, err := config.ToYAMLConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
var cliRedirectURIs []string
|
||||
for _, client := range yamlConfig.StaticClients {
|
||||
if client.ID == staticClientCLI {
|
||||
cliRedirectURIs = client.RedirectURIs
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, cliRedirectURIs)
|
||||
assert.Contains(t, cliRedirectURIs, "/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback")
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type IntegratedValidatorImpl struct{}
|
||||
|
||||
func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) {
|
||||
return &IntegratedValidatorImpl{}, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer {
|
||||
return peer.Copy()
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for _, p := range peers {
|
||||
validatedPeers[p.ID] = struct{}{}
|
||||
}
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) {
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) Stop(_ context.Context) {
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow {
|
||||
return flowResponse
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
@@ -53,6 +54,7 @@ type DataSource interface {
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetStoreEngine() types.Engine
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error)
|
||||
}
|
||||
|
||||
// ConnManager peer connection manager that holds state for current active connections
|
||||
@@ -223,6 +225,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
servicesAuthPassword int
|
||||
servicesAuthPin int
|
||||
servicesAuthOIDC int
|
||||
// Private-service signals — track adoption of NetBird-only mode
|
||||
// (services backed by an embedded proxy peer + access groups).
|
||||
servicesPrivate int
|
||||
servicesPrivateWithGroups int
|
||||
servicesPrivateAccessGroupsSum int
|
||||
servicesWithDirectUpstream int
|
||||
)
|
||||
start := time.Now()
|
||||
metricsProperties := make(properties)
|
||||
@@ -380,9 +388,31 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||
servicesAuthOIDC++
|
||||
}
|
||||
|
||||
if service.Private {
|
||||
servicesPrivate++
|
||||
if len(service.AccessGroups) > 0 {
|
||||
servicesPrivateWithGroups++
|
||||
}
|
||||
servicesPrivateAccessGroupsSum += len(service.AccessGroups)
|
||||
}
|
||||
|
||||
for _, target := range service.Targets {
|
||||
if target.Options.DirectUpstream {
|
||||
servicesWithDirectUpstream++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Proxy / BYOP cluster signals come from the proxies table aggregated
|
||||
// across all accounts in a single store query; nil on FileStore.
|
||||
proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("collect proxy metrics: %v", err)
|
||||
}
|
||||
|
||||
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
||||
metricsProperties["uptime"] = uptime
|
||||
metricsProperties["accounts"] = accounts
|
||||
@@ -430,6 +460,15 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["services_auth_password"] = servicesAuthPassword
|
||||
metricsProperties["services_auth_pin"] = servicesAuthPin
|
||||
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
|
||||
metricsProperties["services_private"] = servicesPrivate
|
||||
metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups
|
||||
metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum
|
||||
metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream
|
||||
metricsProperties["proxy_clusters"] = proxyMetrics.Clusters
|
||||
metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP
|
||||
metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate
|
||||
metricsProperties["proxies"] = proxyMetrics.Proxies
|
||||
metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected
|
||||
metricsProperties["custom_domains"] = customDomains
|
||||
metricsProperties["custom_domains_validated"] = customDomainsValidated
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -123,7 +124,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "peer"},
|
||||
{TargetType: "host"},
|
||||
{TargetType: "host", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||
},
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
||||
@@ -141,6 +142,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
||||
},
|
||||
{
|
||||
ID: "svc3-private",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
AccessGroups: []string{"grp-eng", "grp-ops"},
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "cluster", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -254,6 +265,18 @@ func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, e
|
||||
return 3, 2, nil
|
||||
}
|
||||
|
||||
// GetProxyMetrics returns canned proxy/cluster counts so the
|
||||
// generateProperties test can assert the BYOP signals end-to-end.
|
||||
func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, error) {
|
||||
return store.ProxyMetrics{
|
||||
Clusters: 3,
|
||||
ClustersBYOP: 1,
|
||||
ClustersPrivate: 1,
|
||||
Proxies: 4,
|
||||
ProxiesConnected: 2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||
func TestGenerateProperties(t *testing.T) {
|
||||
ds := mockDatasource{}
|
||||
@@ -393,17 +416,17 @@ func TestGenerateProperties(t *testing.T) {
|
||||
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||
}
|
||||
|
||||
if properties["services"] != 2 {
|
||||
t.Errorf("expected 2 services, got %v", properties["services"])
|
||||
if properties["services"] != 3 {
|
||||
t.Errorf("expected 3 services, got %v", properties["services"])
|
||||
}
|
||||
if properties["services_enabled"] != 1 {
|
||||
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
|
||||
if properties["services_enabled"] != 2 {
|
||||
t.Errorf("expected 2 services_enabled, got %v", properties["services_enabled"])
|
||||
}
|
||||
if properties["services_targets"] != 3 {
|
||||
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
|
||||
if properties["services_targets"] != 4 {
|
||||
t.Errorf("expected 4 services_targets, got %v", properties["services_targets"])
|
||||
}
|
||||
if properties["services_status_active"] != 1 {
|
||||
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
|
||||
if properties["services_status_active"] != 2 {
|
||||
t.Errorf("expected 2 services_status_active, got %v", properties["services_status_active"])
|
||||
}
|
||||
if properties["services_status_pending"] != 1 {
|
||||
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
|
||||
@@ -420,6 +443,9 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["services_target_type_domain"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
|
||||
}
|
||||
if properties["services_target_type_cluster"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_cluster, got %v", properties["services_target_type_cluster"])
|
||||
}
|
||||
if properties["services_auth_password"] != 1 {
|
||||
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
|
||||
}
|
||||
@@ -429,6 +455,33 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["services_auth_pin"] != 0 {
|
||||
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
|
||||
}
|
||||
if properties["services_private"] != 1 {
|
||||
t.Errorf("expected 1 services_private, got %v", properties["services_private"])
|
||||
}
|
||||
if properties["services_private_with_access_groups"] != 1 {
|
||||
t.Errorf("expected 1 services_private_with_access_groups, got %v", properties["services_private_with_access_groups"])
|
||||
}
|
||||
if properties["services_private_access_groups_sum"] != 2 {
|
||||
t.Errorf("expected 2 services_private_access_groups_sum, got %v", properties["services_private_access_groups_sum"])
|
||||
}
|
||||
if properties["services_with_direct_upstream"] != 2 {
|
||||
t.Errorf("expected 2 services_with_direct_upstream, got %v", properties["services_with_direct_upstream"])
|
||||
}
|
||||
if properties["proxy_clusters"] != int64(3) {
|
||||
t.Errorf("expected 3 proxy_clusters, got %v", properties["proxy_clusters"])
|
||||
}
|
||||
if properties["proxy_clusters_byop"] != int64(1) {
|
||||
t.Errorf("expected 1 proxy_clusters_byop, got %v", properties["proxy_clusters_byop"])
|
||||
}
|
||||
if properties["proxy_clusters_private"] != int64(1) {
|
||||
t.Errorf("expected 1 proxy_clusters_private, got %v", properties["proxy_clusters_private"])
|
||||
}
|
||||
if properties["proxies"] != int64(4) {
|
||||
t.Errorf("expected 4 proxies, got %v", properties["proxies"])
|
||||
}
|
||||
if properties["proxies_connected"] != int64(2) {
|
||||
t.Errorf("expected 2 proxies_connected, got %v", properties["proxies_connected"])
|
||||
}
|
||||
if properties["custom_domains"] != int64(3) {
|
||||
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
|
||||
}
|
||||
|
||||
@@ -198,7 +198,11 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to insert account")
|
||||
|
||||
account.PeersG = []nbpeer.Peer{
|
||||
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
{
|
||||
AccountID: "1234",
|
||||
Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}},
|
||||
Status: &nbpeer.PeerStatus{LastSeen: time.Now()},
|
||||
},
|
||||
}
|
||||
|
||||
err = db.Save(account).Error
|
||||
|
||||
@@ -38,7 +38,8 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
@@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
|
||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
// Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing
|
||||
// hook so tests that inject MarkPeerDisconnectedFunc actually observe
|
||||
// disconnect events. Falls through to nil when no hook is set, which
|
||||
// is the original behaviour.
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
if am.DeleteAccountFunc != nil {
|
||||
|
||||
@@ -34,8 +34,11 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
|
||||
|
||||
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, networks, 1)
|
||||
require.Equal(t, "testNetworkId", networks[0].ID)
|
||||
ids := make([]string, 0, len(networks))
|
||||
for _, n := range networks {
|
||||
ids = append(ids, n.ID)
|
||||
}
|
||||
require.ElementsMatch(t, []string{"testNetworkId", "secondNetworkId"}, ids)
|
||||
}
|
||||
|
||||
func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
|
||||
|
||||
@@ -102,7 +102,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
err = transaction.CreateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
}
|
||||
@@ -162,11 +162,20 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
if network.ID != router.NetworkID {
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if existing.AccountID != router.AccountID {
|
||||
return status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
err = transaction.UpdateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
@@ -195,6 +195,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "testRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
@@ -210,6 +211,102 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
|
||||
require.Equal(t, router.Metric, updatedRouter.Metric)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterRejectsCrossAccountID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// Admin of testAccountId tries to update a router that belongs to otherAccountId
|
||||
// by passing the other account's router ID through the URL.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "otherRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedRouter)
|
||||
|
||||
// The other account's router must be untouched.
|
||||
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "otherAccountId", "otherRouterId")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "otherAccountId", stored.AccountID)
|
||||
require.Equal(t, "otherNetworkId", stored.NetworkID)
|
||||
require.Equal(t, "otherPeer", stored.Peer)
|
||||
require.Equal(t, 1, stored.Metric)
|
||||
}
|
||||
|
||||
func Test_CreateRouterRejectsCrossAccountID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// Admin of testAccountId tries to create a router in otherAccountId's network.
|
||||
// The permission check is on router.AccountID (their own), but the network
|
||||
// lookup must fail because (testAccountId, otherNetworkId) does not exist.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "otherNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
createdRouter, err := manager.CreateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, createdRouter)
|
||||
|
||||
// No router should have been created in either account's scope under otherNetworkId.
|
||||
routersInOther, err := s.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, "otherAccountId", "otherNetworkId")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routersInOther, 1)
|
||||
require.Equal(t, "otherRouterId", routersInOther[0].ID)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterRejectsNetworkMismatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// The router exists in testNetworkId, but the caller submits secondNetworkId
|
||||
// (a different network in the same account). The update must be refused.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "secondNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "testRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedRouter)
|
||||
|
||||
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "testAccountId", "testRouterId")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "testNetworkId", stored.NetworkID)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testUserId"
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -29,6 +28,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -63,98 +63,148 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
// syncTime is used as the LastSeen timestamp and for stale request detection
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var expired bool
|
||||
var err error
|
||||
var skipped bool
|
||||
// MarkPeerConnected marks a peer as connected with optimistic-locked
|
||||
// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument
|
||||
// is the start time of the gRPC sync stream that owns this update,
|
||||
// expressed as Unix nanoseconds — only the call whose token is greater
|
||||
// than what's stored wins. LastSeen is written by the database itself;
|
||||
// we never pass it down.
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
}()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
|
||||
if connected && !syncTime.After(peer.Status.LastSeen) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
|
||||
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
|
||||
skipped = true
|
||||
return nil
|
||||
}
|
||||
|
||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome)
|
||||
return err
|
||||
})
|
||||
if skipped {
|
||||
}
|
||||
|
||||
updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
expired := peer.Status != nil && peer.Status.LoginExpired
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
if err != nil {
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// An embedded proxy peer flipping to connected is the trigger for
|
||||
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
|
||||
// tunnel IP. Without an account-wide netmap recompute, user peers keep
|
||||
// the stale synth (or no synth at all on first connect) until some
|
||||
// other change pokes the controller. Fire OnPeersUpdated so the
|
||||
// buffered recompute fans the new state out to every peer.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = syncTime
|
||||
newStatus.Connected = connected
|
||||
// whenever peer got connected that means that it logged in successfully
|
||||
if newStatus.Connected {
|
||||
newStatus.LoginExpired = false
|
||||
}
|
||||
peer.Status = newStatus
|
||||
// MarkPeerDisconnected marks a peer as disconnected, but only when the
|
||||
// stored session token matches the one passed in. A mismatch means a
|
||||
// newer stream has already taken ownership of the peer — disconnects from
|
||||
// the older stream are ignored. LastSeen is written by the database.
|
||||
func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start))
|
||||
}()
|
||||
|
||||
if geo != nil && realIP != nil {
|
||||
location, err := geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
} else {
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
err = transaction.SavePeerLocation(ctx, accountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome)
|
||||
return err
|
||||
}
|
||||
|
||||
updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping",
|
||||
peer.ID, sessionStartedAt)
|
||||
return nil
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
|
||||
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
|
||||
// offline, drive an account-wide netmap recompute so the synthesized
|
||||
// DNS records that pointed at it are pulled. Without this the records
|
||||
// linger client-side at TTL until something else triggers a refresh.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return false, err
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return
|
||||
}
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
|
||||
return oldStatus.LoginExpired, nil
|
||||
}
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
|
||||
@@ -74,8 +74,19 @@ type ProxyMeta struct {
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
// LastSeen is the last time peer was connected to the management service
|
||||
// LastSeen is the last time the peer status was updated (i.e. the last
|
||||
// time we observed the peer being alive on a sync stream). Written by
|
||||
// the database (CURRENT_TIMESTAMP) — callers do not supply it.
|
||||
LastSeen time.Time
|
||||
// SessionStartedAt records when the currently-active sync stream began,
|
||||
// stored as Unix nanoseconds. It acts as the optimistic-locking token
|
||||
// for status updates: a stream is only allowed to mutate the peer's
|
||||
// status when its own token strictly exceeds the stored token (when connecting)
|
||||
// or matches it exactly (for disconnects). Zero means "no
|
||||
// active session". Integer nanoseconds are used so equality is
|
||||
// precision-safe across drivers, and so the predicates compose to a
|
||||
// single bigint comparison.
|
||||
SessionStartedAt int64 `gorm:"not null;default:0"`
|
||||
// Connected indicates whether peer is connected to the management service or not
|
||||
Connected bool
|
||||
// LoginExpired
|
||||
@@ -375,10 +386,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any {
|
||||
return meta
|
||||
}
|
||||
|
||||
// Copy PeerStatus
|
||||
// Copy PeerStatus. SessionStartedAt must be propagated so clone-based
|
||||
// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently
|
||||
// reset the fencing token to zero — that would let any subsequent
|
||||
// SavePeerStatus write reopen the optimistic-lock window.
|
||||
func (p *PeerStatus) Copy() *PeerStatus {
|
||||
return &PeerStatus{
|
||||
LastSeen: p.LastSeen,
|
||||
SessionStartedAt: p.SessionStartedAt,
|
||||
Connected: p.Connected,
|
||||
LoginExpired: p.LoginExpired,
|
||||
RequiresApproval: p.RequiresApproval,
|
||||
|
||||
@@ -2218,6 +2218,9 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
ID: "test-peer-id",
|
||||
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
DNSLabel: "test-peer-dns-label",
|
||||
Status: &nbpeer.PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -33,9 +32,6 @@ func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, err
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
|
||||
peer.ID, peer.Meta.WtVersion, n.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user