Compare commits

..

1 Commits

Author SHA1 Message Date
mlsmaycon
5eb28acb11 [management] Account-scoped ephemeral peer cleanup
Replace the per-peer linked list with a per-account map keyed by
accountID. Each entry holds only the latest disconnect timestamp we
have observed for that account and a single timer that fires the next
sweep. Sweeps query the database for the authoritative stale set,
batch the deletes through peers.Manager.DeletePeers, then drop the
account from the tracker when lastDisc + lifeTime <= now (else
re-arm at horizon + cleanupWindow).

The drop rule is the entire termination story: an account stays
tracked only while OnPeerDisconnected keeps refreshing the
timestamp. There is no internal feedback loop that can advance
lastDisc on its own, so once disconnects stop the account drops in
at most one sweep.

A timestamp beats the ref-counter alternative because the counter
drifts positive in three real situations the cleanup loop has no
signal for: peers deleted via the API while offline, peers that
reconnect within the lifetime window, and management restarts. The
timestamp design never claims to know the size of the stale set —
it only knows the latest disconnect we observed and uses that to
bound when it is safe to drop the account.

OnPeerConnected becomes a no-op. The sweep query already filters
reconnected peers at the database level (peer_status_connected =
false in the WHERE clause), so there is nothing the in-memory
tracker needs to do on reconnect. The interface method is preserved
for call-site compatibility.

LoadInitialPeers no longer runs the catch-up query synchronously.
It schedules a deferred load via time.AfterFunc at a random delay
between 8 and 10 minutes. Without the jitter, every management
replica in a fleet-wide deploy would issue the catch-up query
simultaneously. The catch-up itself is one GROUP BY against the
peers table:
```sql
  SELECT account_id, MAX(peer_status_last_seen)
  FROM peers
  WHERE ephemeral = true AND peer_status_connected = false
  GROUP BY account_id
```
For each row the tracker seeds an entry and arms a sweep at
max(now, last_seen + lifeTime) + cleanupWindow — so accounts whose
backlog is already stale get cleaned soon after the delay elapses,
and accounts that disconnected recently wait the remaining window.
OnPeerDisconnected calls that arrive during the delay window seed
the tracker live, and the catch-up query skips accounts that are
already tracked.

Stop() cancels both the deferred initial-load timer and every
per-account sweep timer, and flips a stopped flag so subsequent
OnPeerDisconnected calls are ignored. This makes restarts and test
teardown clean.

Two new store methods:
  GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
  GetEphemeralAccountsLastDisconnect(ctx)
Both are scoped, indexable queries that the existing peers table
supports without schema changes.

The pending metric is renamed from
management.ephemeral.peers.pending to
management.ephemeral.accounts.tracked to reflect the new semantics
(it now counts accounts on the cleanup list, not peers). Method
names on the metrics type are unchanged so no production call site
has to move. No new metric labels, no per-account cardinality.

The algorithm was validated against an in-memory SQLite peers
table through an 11-scenario prototype kept under proto/, including
pathological-churn and 4-hour randomized simulations. All scenarios
terminate; max observed per-account sweep rate stays bounded near
the lifeTime + cleanupWindow cadence even under sustained
disconnect churn.

Verification: go build, go vet, race-clean tests across the
ephemeral, store, and telemetry packages, plus a clean
golangci-lint pass on the touched packages.
2026-05-19 09:50:14 +02:00
125 changed files with 2135 additions and 11397 deletions

View File

@@ -12,7 +12,6 @@
- [ ] Is a feature enhancement
- [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible)
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -20,66 +20,34 @@ jobs:
per_page: 100,
});
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
async function getVersionHeader(path, ref) {
try {
const res = await github.rest.repos.getContent({
owner: context.repo.owner,
repo: context.repo.repo,
path,
ref,
});
if (!res.data.content) {
return { ok: false, reason: 'no inline content (file too large)' };
}
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
const lines = content
.split('\n')
.slice(0, 20)
.filter(line => versionPattern.test(line));
return { ok: true, lines };
} catch (e) {
return { ok: false, reason: e.message };
}
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of modifiedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
base: base.lines,
head: head.lines,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n` +
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(

View File

@@ -15,7 +15,6 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
- [Contributing to NetBird](#contributing-to-netbird)
- [Contents](#contents)
- [Code of conduct](#code-of-conduct)
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
- [Directory structure](#directory-structure)
- [Development setup](#development-setup)
- [Requirements](#requirements)
@@ -34,14 +33,6 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
By participating, you are expected to uphold this code. Please report
unacceptable behavior to community@netbird.io.
## Discuss changes with the NetBird team first
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
## Directory structure
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.

153
README.md
View File

@@ -1,134 +1,147 @@
<div align="center">
<p align="center">
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
</p>
<p align="center">
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
</a>
<br/>
<br/>
<p align="center">
<img width="234" src="docs/media/logo-full.png"/>
</p>
<p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
</a>
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
</a>
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
</a>
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p>
</div>
<p align="center">
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
</strong>
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
<br/>
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
</p>
<br>
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open Source Network Security in a Single Platform
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-host NetBird (video)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features
| Connectivity | Management | Security | Automation | Platforms |
|---|---|---|---|---|
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
| Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
| [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
| | | | | ✓ OpenWRT |
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
| Connectivity | Management | Security | Automation| Platforms |
|----|----|----|----|----|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
||||| <ul><li>- \[x] Docker</ui></li> |
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
- Check the NetBird [admin UI](https://app.netbird.io/).
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
- Check NetBird [admin UI](https://app.netbird.io/).
- Add more machines.
### Quickstart with self-hosted NetBird
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
**Infrastructure requirements:**
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
- A **public domain** name pointing to the VM.
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
### A bit on NetBird internals
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
### Support acknowledgement
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Acknowledgements
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
### Legal
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -12,13 +12,7 @@ var (
Short: "Print the NetBird's client application version",
Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout())
out := version.NetbirdVersion()
if version.IsDevelopmentVersion(out) {
if commit := version.NetbirdCommit(); commit != "" {
out += "-" + commit
}
}
cmd.Println(out)
cmd.Println(version.NetbirdVersion())
},
}
)

View File

@@ -84,12 +84,6 @@ type Options struct {
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// BlockLANAccess blocks the embedded peer from reaching the host's
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
// when the embedded client must never act as a stepping stone into
// the host's local network (e.g. the proxy's overlay peer).
BlockLANAccess bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
@@ -181,7 +175,6 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound,
BlockLANAccess: &opts.BlockLANAccess,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
@@ -412,21 +405,6 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
}, nil
}
// IdentityForIP looks up a remote peer by its tunnel IP using the
// embedded client's status recorder. Returns the peer's WireGuard public
// key and FQDN. ok=false means the IP isn't in this client's peer
// roster — callers should treat that as "unknown peer".
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if !ip.IsValid() || c.recorder == nil {
return "", "", false
}
state, found := c.recorder.PeerStateByIP(ip.String())
if !found {
return "", "", false
}
return state.PubKey, state.FQDN, true
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()

View File

@@ -52,10 +52,9 @@ func (m *externalChainMonitor) start() {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
done := make(chan struct{})
m.done = done
m.done = make(chan struct{})
go m.run(ctx, done)
go m.run(ctx)
}
func (m *externalChainMonitor) stop() {
@@ -73,8 +72,8 @@ func (m *externalChainMonitor) stop() {
<-done
}
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
defer close(done)
func (m *externalChainMonitor) run(ctx context.Context) {
defer close(m.done)
bo := &backoff.ExponentialBackOff{
InitialInterval: externalMonitorInitInterval,

View File

@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
; Remove autostart entries from every view a previous installer may have used.
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -339,7 +339,8 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
return strings.HasSuffix(qname, "."+entry.Pattern)
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match

View File

@@ -164,54 +164,6 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: true,
shouldMatch: true,
},
{
name: "wildcard label-boundary mismatch (suffix overlap)",
handlerDomain: "*.b.test.",
queryDomain: "x.ab.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard label-boundary match",
handlerDomain: "*.b.test.",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard multi-label match",
handlerDomain: "*.b.test.",
queryDomain: "x.y.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on multi-label apex",
handlerDomain: "*.b.test.",
queryDomain: "b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard no match on unrelated suffix containment",
handlerDomain: "*.example.com.",
queryDomain: "notexample.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard accepts pattern registered without trailing dot",
handlerDomain: "*.b.test",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
}
for _, tt := range tests {
@@ -321,19 +273,6 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "overlapping wildcard suffixes route to correct handler",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "app.ab.test.",
expectedCalls: 1,
expectedHandler: 1,
},
{
name: "root zone with specific domain",
handlers: []struct {

View File

@@ -26,19 +26,6 @@ type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
// client knows about and whether that peer is currently connected. The
// local resolver uses this to suppress A/AAAA answers whose RDATA points
// at a disconnected peer (typical case: a synthesized private-service
// record pointing at an embedded proxy peer that just went offline).
//
// known=false means the IP isn't in the local peerstore at all — the
// record is left alone (it points at something outside our mesh, e.g.
// a non-peer upstream).
type PeerConnectivity interface {
IsConnectedByIP(ip string) (known, connected bool)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
@@ -46,11 +33,6 @@ type Resolver struct {
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
// peerConn, when non-nil, is consulted on every A/AAAA answer to
// drop records pointing at disconnected peers. nil disables the
// filter and preserves the legacy "return whatever is registered"
// behaviour for callers that never wire a status source.
peerConn PeerConnectivity
ctx context.Context
cancel context.CancelFunc
@@ -67,15 +49,6 @@ func NewResolver() *Resolver {
}
}
// SetPeerConnectivity wires the per-IP connectivity check used to filter
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
// Safe to call multiple times; the latest value wins.
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerConn = p
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
@@ -122,7 +95,6 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question)
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
@@ -464,78 +436,6 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
}
}
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
// a known but disconnected peer. The synthesized private-service zones
// emit one A record per connected proxy peer in a cluster; when a peer
// goes offline, the server-side refresh removes the record from the
// next netmap, but the client may still hold the previous netmap for a
// short window. This filter is the local belt to that braces — even on
// the stale netmap, the resolver hides the offline target.
//
// Records pointing at unknown IPs (outside the local peerstore, e.g.
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
// through untouched.
//
// Escape hatch: if filtering would leave the answer empty AND at least
// one record was filtered, the original list is returned. Better to
// hand the client a record that may not respond than NXDOMAIN it
// completely when every proxy peer is offline (the upstream may still
// be reachable some other way, or the peerstore may be stale).
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
if len(records) == 0 {
return records
}
d.mu.RLock()
checker := d.peerConn
d.mu.RUnlock()
if checker == nil {
return records
}
kept := make([]dns.RR, 0, len(records))
var dropped int
for _, rr := range records {
ip := extractRecordIP(rr)
if ip == "" {
kept = append(kept, rr)
continue
}
known, connected := checker.IsConnectedByIP(ip)
if known && !connected {
dropped++
continue
}
kept = append(kept, rr)
}
if dropped == 0 {
return records
}
if len(kept) == 0 {
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
return records
}
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
return kept
}
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
// an A or AAAA record, or "" for any other record type.
func extractRecordIP(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
if r.A == nil {
return ""
}
return r.A.String()
case *dns.AAAA:
if r.AAAA == nil {
return ""
}
return r.AAAA.String()
}
return ""
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()

View File

@@ -30,21 +30,6 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return nil, nil
}
// mockPeerConnectivity returns canned (known, connected) results per IP.
// Used by the disconnected-peer filter tests below. IPs not in the map
// are reported as unknown so the filter leaves them alone.
type mockPeerConnectivity struct {
byIP map[string]struct{ known, connected bool }
}
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
v, ok := m.byIP[ip]
if !ok {
return false, false
}
return v.known, v.connected
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -2667,114 +2652,3 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver.isInManagedZone(qname)
}
}
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
// connectivity-aware filtering layered on top of lookupRecords:
// when an A record's IP belongs to a known peer that's disconnected,
// the record is dropped from the answer. Records for unknown IPs pass
// through. If filtering would empty the answer entirely and at least
// one record was dropped, the original list is restored (escape hatch
// for the "all proxies offline" case).
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
zone := "svc.cluster.netbird."
connectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.10",
}
disconnectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.11",
}
unknownRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "203.0.113.5",
}
type ipState struct{ known, connected bool }
tests := []struct {
name string
records []nbdns.SimpleRecord
connByIP map[string]ipState
wantInOrder []string
}{
{
name: "drops disconnected peer, keeps connected",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: true},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.10"},
},
{
name: "unknown IPs pass through untouched",
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"203.0.113.5"},
},
{
name: "all disconnected falls back to original list",
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: false},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
},
{
name: "no checker wired returns all records",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: nil,
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
if tc.connByIP != nil {
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
for ip, st := range tc.connByIP {
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
}
resolver.SetPeerConnectivity(cm)
}
resolver.Update([]nbdns.CustomZone{{
Domain: strings.TrimSuffix(zone, "."),
Records: tc.records,
NonAuthoritative: true,
}})
var got *dns.Msg
writer := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
got = m
return nil
},
}
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
resolver.ServeDNS(writer, req)
require.NotNil(t, got, "resolver must produce a response")
require.Len(t, got.Answer, len(tc.wantInOrder),
"answer count must match expected: %v", tc.wantInOrder)
for i, want := range tc.wantInOrder {
a, ok := got.Answer[i].(*dns.A)
require.True(t, ok, "answer[%d] must be an A record", i)
assert.Equal(t, want, a.A.String(),
"answer[%d] expected %s got %s", i, want, a.A.String())
}
})
}
}

View File

@@ -301,11 +301,6 @@ func newDefaultServer(
warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
// suppress A/AAAA answers that point at disconnected peers (typical
// case: synthesised private-service records pointing at an embedded
// proxy peer that just went offline).
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
@@ -1391,25 +1386,3 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
}
return nil
}
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
// the local resolver can ask "is this IP a known peer and is it
// connected?" without taking on the peer package as a dependency.
// A nil status recorder always reports known=false so the resolver
// short-circuits to the legacy "return everything" path.
type localPeerConnectivity struct {
status *peer.Status
}
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
if l.status == nil {
return false, false
}
state, ok := l.status.PeerStateByIP(ip)
if !ok {
return false, false
}
return true, state.ConnStatus == peer.StatusConnected
}

View File

@@ -4,8 +4,6 @@ import (
"strings"
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
)
var (
@@ -13,7 +11,7 @@ var (
)
func IsSupported(agentVersion string) bool {
if nbversion.IsDevelopmentVersion(agentVersion) {
if agentVersion == "development" {
return true
}

View File

@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, flags, err := parseRouteMessage(buf[:n])
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
@@ -66,10 +66,6 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
switch msg.Type {
case unix.RTM_ADD:
if systemops.IgnoreAddedDefaultRoute(flags) {
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
continue
}
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
@@ -82,26 +78,22 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, 0, fmt.Errorf("parse RIB: %v", err)
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
r, err := systemops.MsgToRoute(msg)
if err != nil {
return nil, 0, err
}
return r, msg.Flags, nil
return systemops.MsgToRoute(msg)
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.

View File

@@ -185,12 +185,9 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
return s.eventsChan
}
// Status holds a state of peers, signal, management connections and relays.
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
// every private-service request) don't contend against each other.
// Pure read methods take RLock; anything that mutates state takes Lock.
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.RWMutex
mux sync.Mutex
peers map[string]State
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool
@@ -286,8 +283,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
// GetPeer adds peer to Daemon status map
func (d *Status) GetPeer(peerPubKey string) (State, error) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
state, ok := d.peers[peerPubKey]
if !ok {
@@ -297,8 +294,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
}
func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
for _, state := range d.peers {
if state.IP == ip {
@@ -308,25 +305,6 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
return "", false
}
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
}
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}
// RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
@@ -724,8 +702,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return d.localPeer.Clone()
}
@@ -931,8 +909,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -940,14 +918,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetLazyConnection() bool {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return d.lazyConnectionEnabled
}
func (d *Status) GetManagementState() ManagementState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -973,8 +951,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
// if peer is connected to the management then login is not expired
if d.managementState {
@@ -989,8 +967,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -1000,8 +978,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -1030,8 +1008,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.ingressGwMgr == nil {
return nil
}
@@ -1040,16 +1018,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
// shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates)
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
@@ -1065,8 +1043,8 @@ func (d *Status) GetFullStatus() FullStatus {
LazyConnectionEnabled: d.GetLazyConnection(),
}
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
@@ -1241,8 +1219,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
}
func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
}

View File

@@ -63,33 +63,6 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal")
}
func TestStatus_PeerStateByIP(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
state, ok := status.PeerStateByIP("100.64.0.10")
req.True(ok, "known tunnel IP should resolve to a peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
_, ok = status.PeerStateByIP("100.64.0.99")
req.False(ok, "unknown IP must report ok=false")
}
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
state, ok := status.PeerStateByIP("fd00::1")
req.True(ok, "IPv6-only match must resolve to the peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"

View File

@@ -1,9 +0,0 @@
//go:build dragonfly || freebsd || netbsd || openbsd
package systemops
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor.
func IgnoreAddedDefaultRoute(flags int) bool {
return filterRoutesByFlags(flags)
}

View File

@@ -1,21 +0,0 @@
//go:build darwin
package systemops
import "golang.org/x/sys/unix"
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor. Scoped routes
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
// must not trigger an engine restart.
func IgnoreAddedDefaultRoute(flags int) bool {
if filterRoutesByFlags(flags) {
return true
}
if flags&unix.RTF_IFSCOPE != 0 {
return true
}
return false
}

View File

@@ -188,9 +188,7 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
}
doneChan := make(chan struct{})
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
// enough for teardown to finish while staying under that deadline.
timeout := time.NewTimer(20 * time.Second)
timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop()
go func() {

View File

@@ -19,6 +19,8 @@ import (
const (
latestVersion = "latest"
// this version will be ignored
developmentVersion = "development"
)
var errNoUpdateState = errors.New("no update state found")
@@ -481,7 +483,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
}
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
if version.IsDevelopmentVersion(m.currentVersion) {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
}

View File

@@ -64,6 +64,13 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -76,10 +83,28 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

View File

@@ -3,14 +3,15 @@
package system
import (
"bytes"
"context"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
"golang.org/x/sys/unix"
log "github.com/sirupsen/logrus"
"github.com/zcalusic/sysinfo"
@@ -28,11 +29,19 @@ func UpdateStaticInfoAsync() {
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
kernelName, kernelVersion, kernelPlatform := kernelInfo()
info := _getInfo()
for strings.Contains(info, "broken pipe") {
info = _getInfo()
time.Sleep(500 * time.Millisecond)
}
osStr := strings.ReplaceAll(info, "\n", "")
osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ")
osName, osVersion := readOsReleaseFile()
if osName == "" {
osName = kernelName
osName = osInfo[3]
}
systemHostname, _ := os.Hostname()
@@ -49,8 +58,8 @@ func GetInfo(ctx context.Context) *Info {
}
gio := &Info{
Kernel: kernelName,
Platform: kernelPlatform,
Kernel: osInfo[0],
Platform: osInfo[2],
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
@@ -58,7 +67,7 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: kernelVersion,
KernelVersion: osInfo[1],
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
@@ -69,12 +78,18 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func kernelInfo() (string, string, string) {
var uts unix.Utsname
if err := unix.Uname(&uts); err != nil {
return "", "", ""
func _getInfo() string {
cmd := exec.Command("uname", "-srio")
cmd.Stdin = strings.NewReader("some")
var out bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Warnf("getInfo: %s", err)
}
return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
return out.String()
}
func sysInfo() (string, string, string) {

View File

@@ -32,7 +32,6 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
type Controller struct {
@@ -511,7 +510,7 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
for _, peer := range peers {
// Development version is always supported
if version.IsDevelopmentVersion(peer.Meta.WtVersion) {
if peer.Meta.WtVersion == "development" {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)

View File

@@ -2,6 +2,7 @@ package manager
import (
"context"
"math/rand"
"sync"
"time"
@@ -11,44 +12,76 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
// cleanupWindow is the small grace period added on top of the
// staleness horizon before a sweep fires. It absorbs minor clock
// skew between the management server and the database and avoids
// firing a sweep right at the boundary where last_seen could still
// be one tick under the threshold.
cleanupWindow = 1 * time.Minute
// initialLoadMinDelay and initialLoadMaxDelay bracket the random
// delay applied before the post-restart catch-up query runs. Spread
// across replicas this prevents a thundering herd of catch-up
// queries hitting the database simultaneously after a deploy.
initialLoadMinDelay = 8 * time.Minute
initialLoadMaxDelay = 10 * time.Minute
)
var (
timeNow = time.Now
)
type ephemeralPeer struct {
id string
accountID string
deadline time.Time
next *ephemeralPeer
// accountEntry is the per-account state held by the cleanup tracker.
// We don't track which peers are pending — the sweep query gets the
// authoritative list straight from the database every time. We only
// need to know the latest disconnect we've observed for this account
// (so we can decide when it's safe to drop the entry) and the timer
// that will fire the next sweep.
type accountEntry struct {
lastDisconnectedAt time.Time
timer *time.Timer
}
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
// in worst case we will get invalid error message in this manager.
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
// EphemeralManager tracks accounts that may have ephemeral peers in
// need of cleanup and runs a per-account sweep at the appropriate
// time. State is in-memory and account-scoped: a sweep deletes any
// ephemeral peer in the account that has been disconnected for at
// least lifeTime, then either drops the account from the tracker
// (when no recent disconnects have arrived) or re-arms the timer.
type EphemeralManager struct {
store store.Store
peersManager peers.Manager
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
peersLock sync.Mutex
timer *time.Timer
accountsLock sync.Mutex
accounts map[string]*accountEntry
// initialLoadTimer is the one-shot timer used to defer the
// post-restart catch-up query; held so Stop() can cancel it.
initialLoadTimer *time.Timer
// stopped is flipped by Stop() so any timer that fires after
// teardown becomes a no-op instead of touching a half-dismantled
// store.
stopped bool
lifeTime time.Duration
cleanupWindow time.Duration
// initialLoadDelay returns the wall-clock delay to wait before
// running the post-restart catch-up query. Pluggable so tests can
// fire the load immediately.
initialLoadDelay func() time.Duration
// bgCtx is the long-lived context captured at LoadInitialPeers
// time. Timer-driven sweeps use it because they fire long after
// the original gRPC handler ctx that produced an OnPeerDisconnected
// call has been cancelled.
bgCtx context.Context
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
// no-op when the receiver is nil so deployments without an app
// metrics provider work unchanged.
@@ -58,228 +91,265 @@ type EphemeralManager struct {
// NewEphemeralManager instantiate new EphemeralManager
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
return &EphemeralManager{
store: store,
peersManager: peersManager,
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
store: store,
peersManager: peersManager,
accounts: make(map[string]*accountEntry),
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
initialLoadDelay: defaultInitialLoadDelay,
}
}
// SetMetrics attaches a metrics collector. Safe to call once before
// LoadInitialPeers; later attachment is fine but earlier loads won't be
// reflected in the gauge. Pass nil to detach.
// SetMetrics attaches a metrics collector. Pass nil to detach.
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
e.peersLock.Lock()
e.accountsLock.Lock()
e.metrics = m
e.peersLock.Unlock()
e.accountsLock.Unlock()
}
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
// LoadInitialPeers schedules the post-restart catch-up query for a
// random moment 8-10 minutes from now. Returns immediately. The
// catch-up populates the per-account tracker from the database so any
// peers that disconnected before the restart still get cleaned up.
//
// The random delay is critical: without it, every management replica
// hitting the same Postgres instance after a deploy would issue the
// catch-up query simultaneously.
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.loadEphemeralPeers(ctx)
if e.headPeer != nil {
e.timer = time.AfterFunc(e.lifeTime, func() {
e.cleanup(ctx)
})
}
}
// Stop timer
func (e *EphemeralManager) Stop() {
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.timer != nil {
e.timer.Stop()
}
}
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
e.bgCtx = ctx
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.removePeer(peer.ID) {
e.metrics.DecPending(1)
}
// stop the unnecessary timer
if e.headPeer == nil && e.timer != nil {
e.timer.Stop()
e.timer = nil
}
delay := e.initialLoadDelay()
log.WithContext(ctx).Infof("ephemeral peer initial load scheduled in %s", delay)
e.initialLoadTimer = time.AfterFunc(delay, func() {
e.loadInitialAccounts(e.bgCtx)
})
}
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the EphemeralLifeTime period.
// Stop cancels the deferred initial load and any per-account timers.
func (e *EphemeralManager) Stop() {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
e.stopped = true
if e.initialLoadTimer != nil {
e.initialLoadTimer.Stop()
e.initialLoadTimer = nil
}
for _, entry := range e.accounts {
if entry.timer != nil {
entry.timer.Stop()
}
}
e.accounts = make(map[string]*accountEntry)
}
// OnPeerConnected is a no-op in the account-scoped design. The sweep
// query filters out connected peers at the database level, so we don't
// need an explicit "remove from list" signal when a peer reconnects.
// Kept on the interface to preserve the existing call sites.
func (e *EphemeralManager) OnPeerConnected(_ context.Context, _ *nbpeer.Peer) {
}
// OnPeerDisconnected registers a disconnect for the peer's account and
// arms a sweep if one isn't already scheduled. Non-ephemeral peers are
// ignored.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.isPeerOnList(peer.ID) {
return
}
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
e.metrics.IncPending()
if e.timer == nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
})
}
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := e.newDeadLine()
for _, p := range peers {
e.addPeer(p.AccountID, p.ID, t)
}
e.metrics.AddPending(int64(len(peers)))
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
}
func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
e.peersLock.Lock()
now := timeNow()
for p := e.headPeer; p != nil; p = p.next {
if now.Before(p.deadline) {
break
}
deletePeers[p.id] = p
e.headPeer = p.next
if p.next == nil {
e.tailPeer = nil
}
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
if e.headPeer != nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
entry, existed := e.accounts[peer.AccountID]
if !existed {
entry = &accountEntry{}
e.accounts[peer.AccountID] = entry
e.metrics.IncPending()
}
entry.lastDisconnectedAt = now
if entry.timer == nil {
delay := e.lifeTime + e.cleanupWindow
log.WithContext(ctx).Tracef("ephemeral: scheduling sweep for account %s in %s", peer.AccountID, delay)
accountID := peer.AccountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(e.bgCtxOrFallback(ctx), accountID)
})
} else {
e.timer = nil
}
}
// bgCtxOrFallback returns the long-lived background context captured at
// LoadInitialPeers time, falling back to the supplied ctx when the
// manager hasn't been started through LoadInitialPeers (e.g. in tests
// that drive the manager directly). Must be called with the lock held
// or before the timer is armed.
func (e *EphemeralManager) bgCtxOrFallback(ctx context.Context) context.Context {
if e.bgCtx != nil {
return e.bgCtx
}
return ctx
}
// loadInitialAccounts runs the post-restart catch-up query and seeds
// the tracker with one entry per account that has at least one
// disconnected ephemeral peer.
func (e *EphemeralManager) loadInitialAccounts(ctx context.Context) {
accounts, err := e.store.GetEphemeralAccountsLastDisconnect(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to load ephemeral accounts on startup: %v", err)
return
}
e.peersLock.Unlock()
now := timeNow()
added := 0
// Drop the gauge by the number of entries we just took off the list,
// regardless of whether the subsequent DeletePeers call succeeds. The
// list invariant is what the gauge tracks; failed delete batches are
// counted separately via CountCleanupError so we can still see them.
if len(deletePeers) > 0 {
e.metrics.CountCleanupRun()
e.metrics.DecPending(int64(len(deletePeers)))
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
e.metrics.CountCleanupError()
for accountID, lastDisc := range accounts {
// If we already learned about this account via an
// OnPeerDisconnected that arrived during the random delay
// window, prefer the live timestamp.
if _, alreadyTracked := e.accounts[accountID]; alreadyTracked {
continue
}
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
}
}
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{
id: peerID,
accountID: accountID,
deadline: deadline,
}
entry := &accountEntry{lastDisconnectedAt: lastDisc}
horizon := lastDisc.Add(e.lifeTime)
if e.headPeer == nil {
e.headPeer = ep
}
if e.tailPeer != nil {
e.tailPeer.next = ep
}
e.tailPeer = ep
}
// removePeer drops the entry from the linked list. Returns true if a
// matching entry was found and removed so callers can keep the pending
// metric gauge in sync.
func (e *EphemeralManager) removePeer(id string) bool {
if e.headPeer == nil {
return false
}
if e.headPeer.id == id {
e.headPeer = e.headPeer.next
if e.tailPeer.id == id {
e.tailPeer = nil
var delay time.Duration
if horizon.After(now) {
delay = horizon.Sub(now) + e.cleanupWindow
} else {
// Already past the staleness window — sweep right away
// (one cleanupWindow later, to keep startup load smooth
// when many accounts qualify at once).
delay = e.cleanupWindow
}
return true
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
e.accounts[accountID] = entry
added++
}
for p := e.headPeer; p.next != nil; p = p.next {
if p.next.id == id {
// if we remove the last element from the chain then set the last-1 as tail
if e.tailPeer.id == id {
e.tailPeer = p
}
p.next = p.next.next
return true
e.metrics.AddPending(int64(added))
log.WithContext(ctx).Debugf("ephemeral: loaded %d account(s) for cleanup tracking", added)
}
// sweep runs the cleanup pass for a single account. It queries the
// database for disconnected ephemeral peers that have crossed the
// staleness window, deletes them via peers.Manager, and then decides
// whether to drop the account from the tracker or re-arm the timer.
func (e *EphemeralManager) sweep(ctx context.Context, accountID string) {
now := timeNow()
e.accountsLock.Lock()
entry, ok := e.accounts[accountID]
if !ok || e.stopped {
e.accountsLock.Unlock()
return
}
lastDisc := entry.lastDisconnectedAt
entry.timer = nil
e.accountsLock.Unlock()
threshold := now.Add(-e.lifeTime)
stalePeerIDs, err := e.store.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, threshold)
if err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to query stale peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
return
}
if len(stalePeerIDs) > 0 {
log.WithContext(ctx).Tracef("ephemeral: deleting %d peer(s) for account %s", len(stalePeerIDs), accountID)
if err := e.peersManager.DeletePeers(ctx, accountID, stalePeerIDs, activity.SystemInitiator, true); err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to delete peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
return
}
e.metrics.CountCleanupRun()
e.metrics.CountPeersCleaned(int64(len(stalePeerIDs)))
}
return false
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
entry, ok = e.accounts[accountID]
if !ok {
return
}
// Drop rule: if every disconnect we've observed has now crossed
// the staleness window, the sweep we just ran saw everything that
// could possibly need cleaning. Dropping is safe — a future
// disconnect will recreate the entry. The check uses the latest
// lastDisc, which may have advanced (concurrently with the sweep
// itself) due to a new OnPeerDisconnected, in which case we
// correctly re-arm.
horizon := entry.lastDisconnectedAt.Add(e.lifeTime)
if !horizon.After(now) {
delete(e.accounts, accountID)
e.metrics.DecPending(1)
log.WithContext(ctx).Tracef("ephemeral: dropping account %s (lastDisc=%s, horizon=%s, now=%s)",
accountID, lastDisc, horizon, now)
return
}
delay := horizon.Sub(now) + e.cleanupWindow
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
}
func (e *EphemeralManager) isPeerOnList(id string) bool {
for p := e.headPeer; p != nil; p = p.next {
if p.id == id {
return true
}
// rearm reschedules a sweep `delay` from now. Used after a recoverable
// error in the sweep path so the account doesn't get stuck.
func (e *EphemeralManager) rearm(ctx context.Context, accountID string, delay time.Duration) {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
return false
entry, ok := e.accounts[accountID]
if !ok {
return
}
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
}
func (e *EphemeralManager) newDeadLine() time.Time {
return timeNow().Add(e.lifeTime)
// defaultInitialLoadDelay returns a random duration in
// [initialLoadMinDelay, initialLoadMaxDelay). Process-wide
// math/rand is acceptable here — the delay is purely a smoothing
// jitter, not a security primitive.
func defaultInitialLoadDelay() time.Duration {
span := int64(initialLoadMaxDelay - initialLoadMinDelay)
if span <= 0 {
return initialLoadMinDelay
}
return initialLoadMinDelay + time.Duration(rand.Int63n(span))
}

View File

@@ -2,299 +2,544 @@ package manager
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// MockStore is a thin in-memory stand-in that implements only the two
// methods the EphemeralManager uses. It honors the account / ephemeral
// / connected / lastSeen attributes of each peer so the cleanup logic
// can be exercised end-to-end without bringing up sqlite or Postgres.
type MockStore struct {
store.Store
mu sync.Mutex
account *types.Account
}
func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
for _, v := range s.account.Peers {
if v.Ephemeral {
peers = append(peers, v)
func (s *MockStore) GetStaleEphemeralPeerIDsForAccount(_ context.Context, accountID string, olderThan time.Time) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.account == nil || s.account.Id != accountID {
return nil, nil
}
var ids []string
for _, p := range s.account.Peers {
if !p.Ephemeral {
continue
}
if p.Status == nil || p.Status.Connected {
continue
}
if p.Status.LastSeen.Before(olderThan) {
ids = append(ids, p.ID)
}
}
return peers, nil
return ids, nil
}
type MockAccountManager struct {
mu sync.Mutex
nbAccount.Manager
store *MockStore
deletePeerCalls int
bufferUpdateCalls map[string]int
wg *sync.WaitGroup
}
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
a.mu.Lock()
defer a.mu.Unlock()
a.deletePeerCalls++
delete(a.store.account.Peers, peerID)
if a.wg != nil {
a.wg.Done()
func (s *MockStore) GetEphemeralAccountsLastDisconnect(_ context.Context) (map[string]time.Time, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := map[string]time.Time{}
if s.account == nil {
return out, nil
}
return nil
}
func (a *MockAccountManager) GetDeletePeerCalls() int {
a.mu.Lock()
defer a.mu.Unlock()
return a.deletePeerCalls
}
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
a.bufferUpdateCalls = make(map[string]int)
var latest time.Time
hasAny := false
for _, p := range s.account.Peers {
if !p.Ephemeral || p.Status == nil || p.Status.Connected {
continue
}
if !hasAny || p.Status.LastSeen.After(latest) {
latest = p.Status.LastSeen
hasAny = true
}
}
a.bufferUpdateCalls[accountID]++
}
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
return 0
if hasAny {
out[s.account.Id] = latest
}
return a.bufferUpdateCalls[accountID]
return out, nil
}
func (a *MockAccountManager) GetStore() store.Store {
return a.store
}
func TestNewManager(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
// withFakeClock pins timeNow to a settable value for the duration of t.
// Returns a getter and a setter so subtests can advance virtual time.
func withFakeClock(t *testing.T, start time.Time) (get func() time.Time, set func(time.Time)) {
t.Helper()
var mu sync.Mutex
now := start
timeNow = func() time.Time {
return startTime
mu.Lock()
defer mu.Unlock()
return now
}
t.Cleanup(func() { timeNow = time.Now })
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
}
return func() time.Time {
mu.Lock()
defer mu.Unlock()
return now
}, func(v time.Time) {
mu.Lock()
defer mu.Unlock()
now = v
}
}
func TestNewManagerPeerConnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
// newManagerForTest builds a manager with short timers and no random
// initial-load delay so tests run instantly.
func newManagerForTest(t *testing.T, st store.Store, peersMgr peers.Manager) *EphemeralManager {
t.Helper()
mgr := NewEphemeralManager(st, peersMgr)
mgr.lifeTime = 100 * time.Millisecond
mgr.cleanupWindow = 10 * time.Millisecond
mgr.initialLoadDelay = func() time.Duration { return 0 }
t.Cleanup(mgr.Stop)
return mgr
}
func TestNewManagerPeerDisconnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
// TestOnPeerDisconnected_RegistersAndSweeps drives the OnPeerDisconnected
// path with a fake clock: a single ephemeral peer disconnects, we
// advance past the staleness window, and the sweep deletes it.
func TestOnPeerDisconnected_RegistersAndSweeps(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for the one disconnected peer
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
}
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
const (
ephemeralPeers = 10
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
mockStore := &MockStore{}
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
wg := &sync.WaitGroup{}
wg.Add(ephemeralPeers)
mockAM := &MockAccountManager{
store: mockStore,
wg: wg,
}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
peersMgr := peers.NewMockManager(ctrl)
// Set up expectation that DeletePeers will be called once with all peer IDs
peersManager.EXPECT().
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
// Simulate the actual deletion behavior
for _, peerID := range peerIDs {
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
return err
}
var deletedMu sync.Mutex
var deleted []string
var deleteCalls atomic.Int32
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, accountID string, peerIDs []string, _ string, _ bool) error {
deleteCalls.Add(1)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
mockStore.mu.Unlock()
deletedMu.Lock()
deleted = append(deleted, peerIDs...)
deletedMu.Unlock()
return nil
}).
Times(1)
}).AnyTimes()
mgr := NewEphemeralManager(mockStore, peersManager)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
mgr := newManagerForTest(t, mockStore, peersMgr)
// Add peers and disconnect them at slightly different times (within cleanup window)
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
// One ephemeral peer that disconnected "now".
now := getNow()
p := &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
// Advance past lifeTime + cleanupWindow so the timer-driven sweep fires.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool { return deleteCalls.Load() >= 1 }, 2*time.Second, 5*time.Millisecond,
"sweep should fire and delete the stale peer")
deletedMu.Lock()
deletedCopy := append([]string(nil), deleted...)
deletedMu.Unlock()
require.Equal(t, []string{"p1"}, deletedCopy, "only the one ephemeral peer should be deleted")
}
// TestOnPeerDisconnected_NonEphemeralIgnored: a non-ephemeral disconnect
// must not register the account or arm any timer.
func TestOnPeerDisconnected_NonEphemeralIgnored(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
// No DeletePeers expectation — must not be called.
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: false,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "non-ephemeral disconnect must not register an account")
mgr.accountsLock.Unlock()
}
// TestSweep_DropsAccountWhenIdle: after a sweep cleans the stale peers,
// if no more disconnects have arrived the account must be dropped from
// the in-memory tracker.
func TestSweep_DropsAccountWhenIdle(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should be dropped after sweep with no new disconnects")
}
// TestSweep_ReArmsWhenNewDisconnectArrived: simulate the race where a
// fresh disconnect arrives just before the sweep fires. The sweep must
// observe the updated lastDisc and re-arm rather than drop.
func TestSweep_ReArmsWhenNewDisconnectArrived(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
p1 := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p1.ID] = p1
mgr.OnPeerDisconnected(context.Background(), p1)
// Advance most of the way toward the first sweep, then introduce
// a fresh disconnect that resets lastDisc.
setNow(now.Add(mgr.lifeTime - 10*time.Millisecond))
p2 := &nbpeer.Peer{ID: "p2", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
mockStore.account.Peers[p2.ID] = p2
mgr.OnPeerDisconnected(context.Background(), p2)
// Push past p1's staleness so the first sweep runs and cleans p1
// but observes p2 already on the account entry. It must re-arm.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p1"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p1 should be cleaned at the first sweep")
// The account should still be tracked because p2 is younger than lifeTime
// from the sweep's vantage point at this moment.
mgr.accountsLock.Lock()
_, stillTracked := mgr.accounts["acc-1"]
mgr.accountsLock.Unlock()
require.True(t, stillTracked, "account should remain tracked because p2's disconnect kept it active")
// Push past p2's staleness; second sweep cleans p2 and drops the account.
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should drop after the final sweep")
}
// TestSweep_BatchesPeersPerAccount: many ephemeral peers disconnect on
// the same account; a single sweep must delete them all in one
// DeletePeers call.
func TestSweep_BatchesPeersPerAccount(t *testing.T) {
const ephemeralPeers = 8
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
deleteBatches := make(chan []string, 4)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
cp := append([]string(nil), peerIDs...)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
deleteBatches <- cp
return nil
}).Times(1)
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
for i := 0; i < ephemeralPeers; i++ {
id := fmt.Sprintf("p-%d", i)
// Stagger by a fraction of cleanupWindow so they all fall on
// the same sweep tick.
when := now.Add(time.Duration(i) * time.Millisecond)
p := &nbpeer.Peer{ID: id, AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: when}}
mockStore.account.Peers[id] = p
mgr.OnPeerDisconnected(context.Background(), p)
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
}
// Advance time past the lifetime to trigger cleanup
startTime = startTime.Add(testLifeTime + testCleanupWindow)
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
// Wait for all deletions to complete
wg.Wait()
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
select {
case batch := <-deleteBatches:
require.Len(t, batch, ephemeralPeers, "all peers should be deleted in a single batch")
case <-time.After(2 * time.Second):
t.Fatal("expected one batched DeletePeers call")
}
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
// TestLoadInitialAccounts_SeedsFromStore exercises the post-restart
// catch-up path: pre-populate the store, point the manager at it, and
// confirm both already-stale and not-yet-stale peers get cleaned at
// their proper times.
func TestLoadInitialAccounts_SeedsFromStore(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: false,
}
store.account.Peers[p.ID] = p
now := getNow()
// p-stale: already past the staleness window when load runs.
mockStore.account.Peers["p-stale"] = &nbpeer.Peer{
ID: "p-stale", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now.Add(-time.Hour)},
}
// p-fresh: disconnected but not yet stale.
mockStore.account.Peers["p-fresh"] = &nbpeer.Peer{
ID: "p-fresh", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
}
for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: true,
}
store.account.Peers[p.ID] = p
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
// Drive loadInitialAccounts directly with the fake-clock-aware now.
mgr.loadInitialAccounts(context.Background())
// First sweep should fire shortly (cleanupWindow) for the stale peer.
setNow(now.Add(5 * mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p-stale"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p-stale should be deleted on the first sweep")
// p-fresh is not yet stale; advance past its window.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p-fresh"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p-fresh should be deleted once it crosses the staleness window")
}
// TestStop_CancelsPendingWork verifies that Stop() cancels both the
// deferred initial load and per-account sweep timers and that
// subsequent OnPeerDisconnected calls are ignored.
func TestStop_CancelsPendingWork(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
// DeletePeers must NOT be called after Stop.
mgr := NewEphemeralManager(mockStore, peersMgr)
mgr.lifeTime = 100 * time.Millisecond
mgr.cleanupWindow = 10 * time.Millisecond
// Use a long delay so the initial-load timer is still pending.
mgr.initialLoadDelay = func() time.Duration { return time.Hour }
mgr.LoadInitialPeers(context.Background())
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.NotNil(t, mgr.initialLoadTimer, "initial-load timer should be armed")
require.Len(t, mgr.accounts, 1, "account should be tracked after disconnect")
mgr.accountsLock.Unlock()
mgr.Stop()
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "Stop should clear tracked accounts")
require.True(t, mgr.stopped, "stopped flag must be set")
mgr.accountsLock.Unlock()
// Post-stop disconnect must be ignored.
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p2", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "disconnects after Stop must be ignored")
mgr.accountsLock.Unlock()
}
// TestOnPeerConnected_IsNoop: the OnPeerConnected hook is preserved on
// the interface but does nothing in the per-account model — the sweep
// query filters connected peers at the DB level.
func TestOnPeerConnected_IsNoop(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Len(t, mgr.accounts, 1, "disconnect should track the account")
mgr.accountsLock.Unlock()
mgr.OnPeerConnected(context.Background(), &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true})
mgr.accountsLock.Lock()
require.Len(t, mgr.accounts, 1, "OnPeerConnected must be a no-op")
mgr.accountsLock.Unlock()
}
// TestSweep_StoreErrorReArms: if the stale-peer query fails, the
// account must remain tracked and a follow-up sweep gets scheduled.
func TestSweep_StoreErrorReArms(t *testing.T) {
mockStore := &erroringStore{
MockStore: MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)},
}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
mgr := newManagerForTest(t, mockStore, peersMgr)
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
mockStore.fail.Store(true)
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
// Wait until the failing sweep has run at least once.
require.Eventually(t, func() bool { return mockStore.failedCalls.Load() >= 1 },
2*time.Second, 5*time.Millisecond, "expected at least one failing sweep")
mgr.accountsLock.Lock()
_, stillTracked := mgr.accounts["acc-1"]
mgr.accountsLock.Unlock()
require.True(t, stillTracked, "account must remain tracked after a sweep error")
// Recover and ensure the rearmed sweep cleans up.
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mockStore.fail.Store(false)
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p1"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "rearmed sweep should clean up after the store recovers")
}
// erroringStore is a MockStore that can be flipped into a failing mode
// to exercise the sweep's error-rearm path.
type erroringStore struct {
MockStore
fail atomic.Bool
failedCalls atomic.Int32
}
func (s *erroringStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
if s.fail.Load() {
s.failedCalls.Add(1)
return nil, errors.New("synthetic store error")
}
return s.MockStore.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
}
// TestDefaultInitialLoadDelay confirms the jitter falls inside the
// documented [8m, 10m) range — sanity check for the production timer.
func TestDefaultInitialLoadDelay(t *testing.T) {
for i := 0; i < 1000; i++ {
d := defaultInitialLoadDelay()
assert.GreaterOrEqual(t, d, initialLoadMinDelay)
assert.Less(t, d, initialLoadMaxDelay)
}
}
@@ -351,3 +596,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
}
return acc
}
// silence the import "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
// (still needed indirectly for ephemeral.EphemeralLifeTime in production paths).
var _ = ephemeral.EphemeralLifeTime

View File

@@ -5,7 +5,6 @@ package peers
import (
"context"
"fmt"
"net"
"time"
"github.com/rs/xid"
@@ -36,14 +35,6 @@ type Manager interface {
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
// Returns nil with an error when no match exists. No permission check;
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
// to. Used by the proxy's auth path to authorise a request by the calling
// peer's group memberships.
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
}
type managerImpl struct {
@@ -108,26 +99,6 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}
// GetPeerByTunnelIP delegates to the store's indexed lookup.
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
}
// GetPeerWithGroups returns the peer plus its group memberships. Any store
// error returns (nil, nil, err) so callers never receive a valid peer
// alongside a non-nil error.
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
return p, groups, nil
}
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {

View File

@@ -6,7 +6,6 @@ package peers
import (
context "context"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -14,7 +13,6 @@ import (
account "github.com/netbirdio/netbird/management/server/account"
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
peer "github.com/netbirdio/netbird/management/server/peer"
types "github.com/netbirdio/netbird/management/server/types"
)
// MockManager is a mock of Manager interface.
@@ -40,20 +38,6 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}
// DeletePeers mocks base method.
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
m.ctrl.T.Helper()
@@ -113,21 +97,6 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
// GetPeerByTunnelIP mocks base method.
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
}
// GetPeerID mocks base method.
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
m.ctrl.T.Helper()
@@ -143,22 +112,6 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
}
// GetPeerWithGroups mocks base method.
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].([]*types.Group)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
}
// GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
@@ -209,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}

View File

@@ -23,8 +23,6 @@ type Domain struct {
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
// Not persisted.
SupportsCrowdSec *bool `gorm:"-"`
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
SupportsPrivate *bool `gorm:"-"`
}
// EventMeta returns activity event metadata for a domain

View File

@@ -49,7 +49,6 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
SupportsCustomPorts: d.SupportsCustomPorts,
RequireSubdomain: d.RequireSubdomain,
SupportsCrowdsec: d.SupportsCrowdSec,
SupportsPrivate: d.SupportsPrivate,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster

View File

@@ -35,7 +35,6 @@ type proxyManager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
@@ -94,7 +93,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
ret = append(ret, d)
}
@@ -111,7 +109,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
if d.TargetCluster != "" {
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
}
// Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain.

View File

@@ -10,7 +10,7 @@ import (
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
@@ -40,10 +40,6 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
return nil
}
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
@@ -155,3 +151,4 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}

View File

@@ -19,7 +19,6 @@ type Manager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)

View File

@@ -17,11 +17,10 @@ type store interface {
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
@@ -138,11 +137,6 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
}
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
@@ -184,3 +178,4 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
}
return nil
}

View File

@@ -15,16 +15,16 @@ import (
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
@@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
}
return nil, nil
}
func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
@@ -99,9 +99,6 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")

View File

@@ -92,20 +92,6 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
}
// ClusterSupportsPrivate mocks base method.
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper()

View File

@@ -20,9 +20,6 @@ type Capabilities struct {
RequireSubdomain *bool
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
SupportsCrowdsec *bool
// Private indicates whether this proxy supports inbound access via Wireguard
// tunnel and netbird-only authentication policies
Private *bool
}
// Proxy represents a reverse proxy instance
@@ -45,34 +42,10 @@ func (Proxy) TableName() string {
return "proxies"
}
// ClusterType is the source of a proxy cluster.
type ClusterType string
const (
// ClusterTypeAccount is a cluster operated by the account itself (BYOP) —
// at least one proxy row in the cluster carries a non-NULL account_id.
ClusterTypeAccount ClusterType = "account"
// ClusterTypeShared is a cluster operated by NetBird and shared across
// accounts — all proxy rows in the cluster have account_id IS NULL.
ClusterTypeShared ClusterType = "shared"
)
// Cluster represents a group of proxy nodes serving the same address.
//
// Online and ConnectedProxies derive from the same 2-min active window
// the rest of the module uses, but Cluster rows are not gated on it —
// the cluster listing surfaces offline clusters too so operators can
// see and clean them up. The 1-hour heartbeat reaper still bounds the
// table eventually.
type Cluster struct {
ID string
Address string
Type ClusterType
Online bool
ConnectedProxies int
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
SupportsCustomPorts *bool
RequireSubdomain *bool
SupportsCrowdSec *bool
Private *bool
SelfHosted bool
}

View File

@@ -9,7 +9,7 @@ import (
)
type Manager interface {
GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)

View File

@@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteAllServices mocks base method.
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
m.ctrl.T.Helper()
@@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -122,6 +122,21 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()
@@ -137,19 +152,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetClusters mocks base method.
func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID)
ret0, _ := ret[0].([]proxy.Cluster)
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetClusters indicates an expected call of GetClusters.
func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call {
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method.
@@ -182,21 +197,6 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()

View File

@@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
return
}
clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -196,15 +196,10 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address,
Type: api.ProxyClusterType(c.Type),
Online: c.Online,
ConnectedProxies: c.ConnectedProxies,
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdSec,
Private: c.Private,
Id: c.ID,
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
SelfHosted: c.SelfHosted,
})
}

View File

@@ -81,8 +81,6 @@ type ClusterDeriver interface {
type CapabilityProvider interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
@@ -114,12 +112,8 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
// GetClusters returns every proxy cluster visible to the account
// (shared + its own BYOP), regardless of whether any proxy in the
// cluster is currently heartbeating. Each cluster is enriched with the
// capability flags reported by its active proxies so the dashboard can
// render feature support without a second round-trip.
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -128,19 +122,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
return nil, status.NewPermissionDeniedError()
}
clusters, err := m.store.GetProxyClusters(ctx, accountID)
if err != nil {
return nil, err
}
for i := range clusters {
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
}
return clusters, nil
return m.store.GetActiveProxyClusters(ctx, accountID)
}
// DeleteAccountCluster removes all proxy registrations for the given cluster address
@@ -210,9 +192,6 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
target.Host = resource.Domain
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
case service.TargetTypeCluster:
// Cluster targets carry the upstream address on target_id; the
// proxy resolves the destination at request time.
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
@@ -784,10 +763,6 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err
}
case service.TargetTypeCluster:
if err := validateClusterTarget(target); err != nil {
return err
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
@@ -795,13 +770,6 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func validateClusterTarget(target *service.Target) error {
if !target.Options.DirectUpstream {
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
}
return nil
}
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
@@ -978,14 +946,12 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
return fmt.Errorf("failed to get services: %w", err)
}
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
}
return nil

View File

@@ -1344,66 +1344,3 @@ func TestValidateSubdomainRequirement(t *testing.T) {
})
}
}
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
// No peer or resource lookups must be issued for cluster targets.
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Options: rpservice.TargetOptions{DirectUpstream: true},
},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
}
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
// store-side check that cluster targets must opt into the host-stack dial
// path. Without DirectUpstream the proxy would route this target through
// the embedded NetBird client and fail on every request.
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "backend.lan",
},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
require.Error(t, err, "cluster target without direct_upstream must be rejected")
assert.ErrorContains(t, err, "direct upstream disabled")
}
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
mgr := &Manager{store: mockStore}
svc := &rpservice.Service{
ID: "svc-1",
AccountID: accountID,
Targets: []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "127.0.0.1",
},
},
}
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
}

View File

@@ -45,11 +45,10 @@ const (
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet"
TargetTypeCluster TargetType = "cluster"
TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet"
SourcePermanent = "permanent"
SourceEphemeral = "ephemeral"
@@ -61,11 +60,6 @@ type TargetOptions struct {
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
// the target via the proxy host's network stack. Useful for upstreams
// reachable without WireGuard (public APIs, LAN services, localhost
// sidecars). Default false.
DirectUpstream bool `json:"direct_upstream,omitempty"`
}
type Target struct {
@@ -73,7 +67,7 @@ type Target struct {
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port uint16 `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
@@ -206,10 +200,6 @@ type Service struct {
Mode string `gorm:"default:'http'"`
ListenPort uint16
PortAutoAssigned bool
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
Private bool
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
@@ -309,12 +299,6 @@ func (s *Service) ToAPIResponse() *api.Service {
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
Private: &s.Private,
}
if len(s.AccessGroups) > 0 {
groups := append([]string(nil), s.AccessGroups...)
resp.AccessGroups = &groups
}
if s.ProxyCluster != "" {
@@ -324,7 +308,6 @@ func (s *Service) ToAPIResponse() *api.Service {
return resp
}
// ToProtoMapping converts the service into the wire format the proxy consumes.
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := s.buildPathMappings()
@@ -366,7 +349,6 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
RewriteRedirects: s.RewriteRedirects,
Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec
Private: s.Private,
}
if r := restrictionsToProto(s.Restrictions); r != nil {
@@ -473,8 +455,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
}
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
@@ -496,22 +477,17 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders
}
if opts.DirectUpstream {
apiOpts.DirectUpstream = &opts.DirectUpstream
}
return apiOpts
}
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
return nil
}
popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
DirectUpstream: opts.DirectUpstream,
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
}
if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
@@ -561,9 +537,6 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders
}
if o.DirectUpstream != nil {
opts.DirectUpstream = *o.DirectUpstream
}
return opts, nil
}
@@ -578,14 +551,6 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
if req.ListenPort != nil {
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
}
if req.Private != nil {
s.Private = *req.Private
}
if req.AccessGroups != nil {
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
} else {
s.AccessGroups = nil
}
targets, err := targetsFromAPI(accountID, req.Targets)
if err != nil {
@@ -775,9 +740,6 @@ func (s *Service) Validate() error {
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err
}
if err := s.validatePrivateRequirements(); err != nil {
return err
}
switch s.Mode {
case ModeHTTP:
@@ -791,23 +753,6 @@ func (s *Service) Validate() error {
}
}
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
func (s *Service) validatePrivateRequirements() error {
if !s.Private {
return nil
}
if s.Mode != "" && s.Mode != ModeHTTP {
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
}
if len(s.AccessGroups) == 0 {
return errors.New("private services require at least one access group")
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
}
return nil
}
func (s *Service) validateHTTPMode() error {
if s.Domain == "" {
return errors.New("service domain is required")
@@ -854,21 +799,11 @@ func (s *Service) validateHTTPTargets() error {
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// Host is normally overwritten by replaceHostByLookup with the
// resolved peer IP / resource address; operator-supplied values
// are honored only when DirectUpstream is set. Validate the
// override here so misconfigured hosts fail fast at API time.
if err := validateDirectUpstreamHost(i, target); err != nil {
return err
}
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
case TargetTypeCluster:
if err := validateClusterTarget(i, target); err != nil {
return err
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
@@ -886,67 +821,25 @@ func (s *Service) validateHTTPTargets() error {
return nil
}
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
func validateClusterTarget(idx int, target *Target) error {
host := strings.TrimSpace(target.Host)
if host == "" {
return fmt.Errorf("target %d: has empty host", idx)
}
if !target.Options.DirectUpstream {
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
}
return validateDirectUpstreamHost(idx, target)
}
// validateDirectUpstreamHost validates the operator-supplied Host on a
// peer/host/domain target when DirectUpstream is set. Empty Host is
// allowed — the lookup fills in the default peer IP / resource address.
// Without DirectUpstream the Host value is silently overwritten by
// replaceHostByLookup, so we don't validate it (preserves the historical
// behaviour where APIs accepted any value and dropped it). Non-empty
// Host with DirectUpstream must look like a hostname or IP and must
// not carry a port (port lives on Target.Port).
func validateDirectUpstreamHost(idx int, target *Target) error {
if !target.Options.DirectUpstream {
return nil
}
host := strings.TrimSpace(target.Host)
if host == "" {
return nil
}
if strings.ContainsAny(host, " \t/") {
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
}
if _, _, err := net.SplitHostPort(host); err == nil {
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
}
return nil
}
func (s *Service) validateL4Target(target *Target) error {
// L4 services have a single target; per-target disable is meaningless
// (use the service-level Enabled flag instead). Force it on so that
// buildPathMappings always includes the target in the proto.
target.Enabled = true
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
if target.TargetId == "" {
return errors.New("target_id is required for L4 services")
}
if target.TargetType != TargetTypeCluster && target.Port == 0 {
return errors.New("target port is required for L4 services")
}
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
if err := validateDirectUpstreamHost(0, target); err != nil {
return err
}
// OK
case TargetTypeSubnet:
if target.Host == "" {
return errors.New("target host is required for subnet targets")
}
case TargetTypeCluster:
// target_id carries the cluster address; the proxy resolves
// the upstream at request time.
default:
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
}
@@ -1281,11 +1174,6 @@ func (s *Service) Copy() *Service {
}
}
var accessGroups []string
if len(s.AccessGroups) > 0 {
accessGroups = append([]string(nil), s.AccessGroups...)
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
@@ -1307,8 +1195,6 @@ func (s *Service) Copy() *Service {
Mode: s.Mode,
ListenPort: s.ListenPort,
PortAutoAssigned: s.PortAutoAssigned,
Private: s.Private,
AccessGroups: accessGroups,
}
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -1117,191 +1116,3 @@ func TestValidate_HeaderAuths(t *testing.T) {
assert.Contains(t, err.Error(), "exceeds maximum length")
})
}
func TestValidate_HTTPClusterTarget(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
}
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
}
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
// rule that operator-supplied Host is mandatory: cluster targets dial the
// upstream via the host network stack (direct_upstream is implied), so an
// empty Host leaves the proxy with nothing to dial.
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
}
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
// half of the cluster-target rule: DirectUpstream must be true so the
// stdlib transport branch in MultiTransport is taken. Without it the
// embedded NetBird client would try to dial the cluster address through
// the WG tunnel, which is the wrong network for a cluster upstream.
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
}
func TestValidate_L4ClusterTarget(t *testing.T) {
rp := validProxy()
rp.Mode = ModeTCP
rp.ListenPort = 9000
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
}
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
svc := validProxy()
svc.Private = true
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
cp := svc.Copy()
require.NotNil(t, cp)
assert.True(t, cp.Private)
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
cp.Private = false
assert.True(t, svc.Private)
cp.AccessGroups[0] = "grp-other"
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
}
func TestService_APIRoundtrip_Private(t *testing.T) {
enabled := true
private := true
accessGroups := []string{"grp-admins"}
targets := []api.ServiceTarget{{
TargetId: "eu.proxy.netbird.io",
TargetType: api.ServiceTargetTargetType("cluster"),
Protocol: "http",
Port: 80,
Enabled: true,
}}
req := &api.ServiceRequest{
Name: "svc-private",
Domain: "myapp.eu.proxy.netbird.io",
Enabled: enabled,
Private: &private,
AccessGroups: &accessGroups,
Targets: &targets,
}
svc := &Service{}
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
assert.True(t, svc.Private)
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
resp := svc.ToAPIResponse()
require.NotNil(t, resp.Private)
assert.True(t, *resp.Private)
require.NotNil(t, resp.AccessGroups)
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
}
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "access group")
}
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-sso"},
}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
}
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Mode = ModeTCP
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "HTTP")
}

View File

@@ -20,20 +20,6 @@ type KeyPair struct {
type Claims struct {
jwt.RegisteredClaims
Method auth.Method `json:"method"`
// Email is the calling user's email address. Carried so the
// proxy can stamp identity on upstream requests (e.g.
// x-litellm-end-user-id) without an extra management
// round-trip on every cookie-bearing request.
Email string `json:"email,omitempty"`
// Groups carries the user's group IDs so the proxy can stamp them
// onto upstream requests (X-NetBird-Groups) from the cookie path
// without an extra management round-trip.
Groups []string `json:"groups,omitempty"`
// GroupNames carries the human-readable display names for the ids
// in Groups, ordered identically (positional pairing). Slice may be
// shorter than Groups for tokens minted before names were
// resolvable; the consumer falls back to ids for missing positions.
GroupNames []string `json:"group_names,omitempty"`
}
func GenerateKeyPair() (*KeyPair, error) {
@@ -48,13 +34,7 @@ func GenerateKeyPair() (*KeyPair, error) {
}, nil
}
// SignToken mints a session JWT for the given user and domain. email,
// groups, and groupNames, when non-empty, are embedded so the proxy can
// authorise and stamp identity for policy-aware middlewares without a
// management round-trip on every cookie-bearing request. groupNames
// pairs positionally with groups; pass nil when names couldn't be
// resolved.
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
if err != nil {
return "", fmt.Errorf("decode private key: %w", err)
@@ -76,10 +56,7 @@ func SignToken(privKeyB64, userID, email, domain string, method auth.Method, gro
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
Method: method,
Email: email,
Groups: append([]string(nil), groups...),
GroupNames: append([]string(nil), groupNames...),
Method: method,
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)

View File

@@ -9,7 +9,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
@@ -137,12 +136,9 @@ type proxyConnection struct {
tokenID string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
// syncStream is set when the proxy connected via SyncMappings.
// When non-nil, the sender goroutine uses this instead of stream.
syncStream proto.ProxyService_SyncMappingsServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
}
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
@@ -210,323 +206,145 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
s.proxyController = proxyController
}
// proxyConnectParams holds the validated parameters extracted from either
// a GetMappingUpdateRequest or a SyncMappingsInit message.
type proxyConnectParams struct {
proxyID string
address string
capabilities *proto.ProxyCapabilities
}
// GetMappingUpdate handles the control stream with proxy clients
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
if err != nil {
return err
}
params.capabilities = req.GetCapabilities()
ctx := stream.Context()
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
stream: stream,
})
if err != nil {
return err
}
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
s.cleanupFailedSnapshot(stream.Context(), conn)
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
}
// SyncMappings implements the bidirectional SyncMappings RPC.
// It mirrors GetMappingUpdate but provides application-level back-pressure:
// management waits for an ack from the proxy before sending the next batch.
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
init, err := recvSyncInit(stream)
if err != nil {
return err
}
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
if err != nil {
return err
}
params.capabilities = init.GetCapabilities()
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
syncStream: stream,
})
if err != nil {
return err
}
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
s.cleanupFailedSnapshot(stream.Context(), conn)
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.drainRecv(stream, errChan)
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
}
// recvSyncInit receives and validates the first message on a SyncMappings stream.
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
firstMsg, err := stream.Recv()
if err != nil {
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
}
init := firstMsg.GetInit()
if init == nil {
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
}
return init, nil
}
// validateProxyConnect validates the proxy ID and address, and checks cluster
// address availability for account-scoped tokens.
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
if proxyID == "" {
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
}
if !isProxyAddressValid(address) {
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
if err != nil {
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
}
}
return proxyConnectParams{proxyID: proxyID, address: address}, nil
}
// registerProxyConnection creates a proxyConnection, registers it with the
// proxy manager and cluster, and stores it in connectedProxies. The caller
// provides a partially initialised connSeed with stream-specific fields set;
// the remaining fields are filled in here.
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
peerInfo := PeerIPFromContext(ctx)
log.Infof("New proxy connection from %s", peerInfo)
proxyID := req.GetProxyId()
if proxyID == "" {
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
}
proxyAddress := req.GetAddress()
if !isProxyAddressValid(proxyAddress) {
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
var accountID *string
var tokenID string
if token := GetProxyTokenFromContext(ctx); token != nil {
if token.AccountID != nil {
accountID = token.AccountID
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
sessionID := uuid.NewString()
s.supersedePriorConnection(params.proxyID, sessionID)
connCtx, cancel := context.WithCancel(ctx)
connSeed.proxyID = params.proxyID
connSeed.sessionID = sessionID
connSeed.address = params.address
connSeed.accountID = accountID
connSeed.tokenID = tokenID
connSeed.capabilities = params.capabilities
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
connSeed.ctx = connCtx
connSeed.cancel = cancel
var caps *proxy.Capabilities
if c := params.capabilities; c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
Private: c.Private,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
if err != nil {
cancel()
if accountID != nil {
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
s.connectedProxies.Store(params.proxyID, connSeed)
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
}
return connSeed, proxyRecord, nil
}
// supersedePriorConnection cancels any existing connection for the given proxy.
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": newSessionID,
"new_session_id": sessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
}
// cleanupFailedSnapshot removes the connection from the cluster and store
// after a snapshot send failure.
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
}
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
conn.cancel()
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
}
}
// drainRecv consumes and discards messages from a bidirectional stream.
// The proxy sends an ack for every incremental update; we don't need them
// after the snapshot phase. Recv errors are forwarded to errChan.
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
for {
if _, err := stream.Recv(); err != nil {
errChan <- err
return
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
if err != nil {
cancel()
if accountID != nil {
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
}
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
// and wait for termination. When bidi is true, normal stream closure (EOF,
// canceled) is treated as a clean disconnect rather than an error.
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{
"proxy_id": conn.proxyID,
"session_id": conn.sessionID,
"address": conn.address,
"cluster_addr": conn.address,
"account_id": conn.accountID,
"proxy_id": proxyID,
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
cancel()
return
}
defer s.disconnectProxy(conn)
go s.heartbeat(conn.ctx, conn, proxyRecord)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
go s.heartbeat(connCtx, conn, proxyRecord)
select {
case err := <-errChan:
if bidi && isStreamClosed(err) {
log.Infof("Proxy %s stream closed", conn.proxyID)
return nil
}
log.Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
case <-conn.ctx.Done():
log.Infof("Proxy %s context canceled", conn.proxyID)
return conn.ctx.Err()
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done():
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
return connCtx.Err()
}
}
// disconnectProxy removes the connection from cluster and store, unless it
// has already been superseded by a newer connection.
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
conn.cancel()
return
}
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
}
conn.cancel()
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
}
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
// one batch, then waits for the proxy to ack before sending the next.
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
if s.snapshotBatchSize <= 0 {
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
}
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
return err
}
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
for _, m := range mappings[i:end] {
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
if err != nil {
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
}
m.AuthToken = token
}
if err := stream.Send(&proto.SyncMappingsResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
if err := waitForAck(stream); err != nil {
return err
}
}
if len(mappings) == 0 {
if err := stream.Send(&proto.SyncMappingsResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
if err := waitForAck(stream); err != nil {
return err
}
}
return nil
}
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
msg, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive ack: %w", err)
}
if msg.GetAck() == nil {
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
}
return nil
}
// heartbeat updates the proxy's last_seen timestamp every minute and
// disconnects the proxy if its access token has been revoked.
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
@@ -563,9 +381,6 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
if s.snapshotBatchSize <= 0 {
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
}
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
@@ -645,26 +460,12 @@ func isProxyAddressValid(addr string) bool {
return err == nil
}
// isStreamClosed returns true for errors that indicate normal stream
// termination: io.EOF, context cancellation, or gRPC Canceled.
func isStreamClosed(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return true
}
return status.Code(err) == codes.Canceled
}
// sender handles sending messages to proxy.
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
// otherwise the legacy GetMappingUpdateResponse stream is used.
// sender handles sending messages to proxy
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
for {
select {
case resp := <-conn.sendChan:
if err := conn.sendResponse(resp); err != nil {
if err := conn.stream.Send(resp); err != nil {
errChan <- err
return
}
@@ -674,17 +475,6 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
}
}
// sendResponse sends a mapping update on whichever stream the proxy connected with.
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
if conn.syncStream != nil {
return conn.syncStream.Send(&proto.SyncMappingsResponse{
Mapping: resp.Mapping,
InitialSyncComplete: resp.InitialSyncComplete,
})
}
return conn.stream.Send(resp)
}
// SendAccessLog processes access log from proxy
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
@@ -751,15 +541,10 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
return true
}
connUpdate = &proto.GetMappingUpdateResponse{
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
}
}
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
connUpdate = filterMappingsForProxy(conn, connUpdate)
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
return true
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
@@ -888,20 +673,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
}
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
// Private mappings require SupportsPrivateService; custom-port L4 mappings
// require SupportsCustomPorts. Remove operations always pass so proxies can
// clean up.
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
// Old proxies that never reported capabilities are skipped for non-TLS L4
// mappings with a custom listen port, since they don't understand the
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
// are new enough to handle the mapping. TLS uses SNI routing and works on
// any proxy. Delete operations are always sent so proxies can clean up.
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
return true
}
if mapping.GetPrivate() {
caps := conn.capabilities
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
return false
}
}
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
return true
}
@@ -910,29 +691,6 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
}
// filterMappingsForProxy drops mappings the proxy cannot safely receive
// (e.g. private mappings to a proxy without SupportsPrivateService).
// Returns the input unchanged when no filtering is needed.
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
if update == nil || len(update.Mapping) == 0 {
return update
}
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, m := range update.Mapping {
if !proxyAcceptsMapping(conn, m) {
continue
}
kept = append(kept, m)
}
if len(kept) == len(update.Mapping) {
return update
}
return &proto.GetMappingUpdateResponse{
Mapping: kept,
InitialSyncComplete: update.InitialSyncComplete,
}
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
@@ -994,10 +752,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
// secrets and have no user-level group context, so groups stay nil. Email
// is also empty — these schemes don't resolve a user record at sign time.
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
if err != nil {
return nil, err
}
@@ -1086,7 +841,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
}
}
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" {
return "", nil
}
@@ -1094,11 +849,8 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
token, err := sessionkey.SignToken(
service.SessionPrivateKey,
userId,
userEmail,
service.Domain,
method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry,
)
if err != nil {
@@ -1109,26 +861,6 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
return token, nil
}
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
// into parallel id and name slices. ids[i] and names[i] always pair to
// the same group. nil entries (orphan ids the manager couldn't resolve)
// are skipped so the consumer can rely on positional pairing.
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
if len(groups) == 0 {
return nil, nil
}
ids = make([]string, 0, len(groups))
names = make([]string, 0, len(groups))
for _, g := range groups {
if g == nil {
continue
}
ids = append(ids, g.ID)
names = append(names, g.Name)
}
return ids, names
}
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
@@ -1393,9 +1125,7 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return verifier, redirectURL, nil
}
// GenerateSessionToken creates a signed session JWT for the given domain and
// user. The user's group memberships are embedded in the token so policy-aware
// middlewares on the proxy can authorise without an extra management round-trip.
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
@@ -1406,29 +1136,11 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
return "", fmt.Errorf("no session key configured for domain: %s", domain)
}
var (
email string
groupIDs []string
groupNames []string
)
if s.usersManager != nil {
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
if uerr != nil {
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
} else if user != nil {
email = user.Email
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
}
}
return sessionkey.SignToken(
service.SessionPrivateKey,
userID,
email,
domain,
method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry,
)
}
@@ -1532,7 +1244,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -1545,7 +1257,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
user, err := s.usersManager.GetUser(ctx, userID)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -1579,15 +1291,12 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: access denied")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
}, nil
}
@@ -1597,13 +1306,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"email": user.Email,
}).Debug("ValidateSession: access granted")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
return &proto.ValidateSessionResponse{
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
}, nil
}
@@ -1636,154 +1342,3 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
}
func ptr[T any](v T) *T { return &v }
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
// checks the peer's group membership against the service's access groups.
// Peers without a user (machine agents, automation workloads) are first-class
// callers; authorisation runs off peer-group memberships rather than the
// optional owning user's auto-groups. On success a session JWT is minted so
// the proxy can install a cookie and skip subsequent management round-trips.
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
domain := req.GetDomain()
tunnelIPStr := req.GetTunnelIp()
if domain == "" || tunnelIPStr == "" {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "missing domain or tunnel_ip",
}, nil
}
tunnelIP := net.ParseIP(tunnelIPStr)
if tunnelIP == nil {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "invalid_tunnel_ip",
}, nil
}
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "service_not_found",
}, nil
}
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
// validate and mint session cookies for their own account's domains.
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
if err != nil || peer == nil {
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
UserId: principalID,
UserEmail: displayIdentity,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
if err != nil {
return nil, err
}
log.WithFields(log.Fields{
"domain": domain,
"tunnel_ip": tunnelIPStr,
"peer_id": peer.ID,
"principal_id": principalID,
}).Debug("ValidateTunnelPeer: access granted")
return &proto.ValidateTunnelPeerResponse{
Valid: true,
UserId: principalID,
UserEmail: displayIdentity,
SessionToken: token,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security
// boundary and must not trust upstream state). Bearer-auth services authorise
// against DistributionGroups when populated. Non-private non-bearer services
// are open.
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
if service.Private {
if len(service.AccessGroups) == 0 {
return fmt.Errorf("private service has no access groups")
}
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
}
return nil
}
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
// else a non-nil error.
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
if len(allowedGroups) == 0 {
return fmt.Errorf("no allowed groups configured")
}
allowed := make(map[string]struct{}, len(allowedGroups))
for _, g := range allowedGroups {
allowed[g] = struct{}{}
}
for _, g := range peerGroupIDs {
if _, ok := allowed[g]; ok {
return nil
}
}
return fmt.Errorf("peer not in allowed groups")
}

View File

@@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s
return nil, errors.New("service not found for domain: " + domain)
}
func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
@@ -129,14 +129,6 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
return user, nil
}
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.GetUser(ctx, userID)
if err != nil {
return nil, nil, err
}
return user, nil, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -428,46 +420,3 @@ func TestGetAccountProxyByDomain(t *testing.T) {
})
}
}
func TestCheckPeerGroupAccess(t *testing.T) {
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: nil}
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
require.Error(t, err)
assert.Contains(t, err.Error(), "no access groups")
})
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
})
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
})
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-allowed"},
},
},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("non-private non-bearer is open", func(t *testing.T) {
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
})
}

View File

@@ -1,411 +0,0 @@
package grpc
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records
// sent messages and returns pre-loaded ack responses from Recv.
type syncRecordingStream struct {
grpc.ServerStream
mu sync.Mutex
sent []*proto.SyncMappingsResponse
recvMsgs []*proto.SyncMappingsRequest
recvIdx int
}
func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sent = append(s.sent, m)
return nil
}
func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.recvIdx >= len(s.recvMsgs) {
return nil, fmt.Errorf("no more recv messages")
}
msg := s.recvMsgs[s.recvIdx]
s.recvIdx++
return msg, nil
}
func (s *syncRecordingStream) Context() context.Context { return context.Background() }
func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil }
func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil }
func (s *syncRecordingStream) SetTrailer(metadata.MD) {}
func (s *syncRecordingStream) SendMsg(any) error { return nil }
func (s *syncRecordingStream) RecvMsg(any) error { return nil }
func ackMsg() *proto.SyncMappingsRequest {
return &proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
}
}
func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 7 // 3 + 3 + 1 → 3 batches, 3 acks (one per batch, including final)
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches")
assert.Len(t, stream.sent[0].Mapping, 3)
assert.False(t, stream.sent[0].InitialSyncComplete)
assert.Len(t, stream.sent[1].Mapping, 3)
assert.False(t, stream.sent[1].InitialSyncComplete)
assert.Len(t, stream.sent[2].Mapping, 1)
assert.True(t, stream.sent[2].InitialSyncComplete)
// All 3 acks consumed — including the final batch.
assert.Equal(t, 3, stream.recvIdx)
}
func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 100
const totalServices = 5
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Len(t, stream.sent, 1)
assert.Len(t, stream.sent[0].Mapping, totalServices)
assert.True(t, stream.sent[0].InitialSyncComplete)
assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed")
}
func TestSendSnapshotSync_EmptySnapshot(t *testing.T) {
const cluster = "cluster.example.com"
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
s := newSnapshotTestServer(t, 500)
s.serviceManager = mgr
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete")
assert.Empty(t, stream.sent[0].Mapping)
assert.True(t, stream.sent[0].InitialSyncComplete)
assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed")
}
func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 4 // 2 batches → 1 ack needed, but we provide none
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
// No acks available — Recv will return error.
stream := &syncRecordingStream{}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.Error(t, err)
assert.Contains(t, err.Error(), "receive ack")
// First batch should have been sent before the error.
require.Len(t, stream.sent, 1)
}
func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 4
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
// Send an init message instead of an ack.
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{
{Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}},
},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.Error(t, err)
assert.Contains(t, err.Error(), "expected ack")
}
func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) {
// Verify batches are sent strictly sequentially — batch N+1 is not sent
// until the ack for batch N is received, including the final batch.
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 6 // 3 batches, 3 acks
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
var mu sync.Mutex
var events []string
// Build a stream that logs send/recv events so we can verify ordering.
ackCh := make(chan struct{}, 3)
stream := &orderTrackingStream{
mu: &mu,
events: &events,
ackCh: ackCh,
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
// Feed acks asynchronously after a short delay to simulate real proxy.
go func() {
for range 3 {
time.Sleep(10 * time.Millisecond)
ackCh <- struct{}{}
}
}()
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
mu.Lock()
defer mu.Unlock()
// Expected: send, recv-ack, send, recv-ack, send, recv-ack.
require.Len(t, events, 6)
assert.Equal(t, "send", events[0])
assert.Equal(t, "recv", events[1])
assert.Equal(t, "send", events[2])
assert.Equal(t, "recv", events[3])
assert.Equal(t, "send", events[4])
assert.Equal(t, "recv", events[5])
}
// orderTrackingStream logs "send" and "recv" events and blocks Recv until
// an ack is signaled via ackCh.
type orderTrackingStream struct {
grpc.ServerStream
mu *sync.Mutex
events *[]string
ackCh chan struct{}
}
func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error {
s.mu.Lock()
*s.events = append(*s.events, "send")
s.mu.Unlock()
return nil
}
func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) {
<-s.ackCh
s.mu.Lock()
*s.events = append(*s.events, "recv")
s.mu.Unlock()
return ackMsg(), nil
}
func (s *orderTrackingStream) Context() context.Context { return context.Background() }
func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil }
func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil }
func (s *orderTrackingStream) SetTrailer(metadata.MD) {}
func (s *orderTrackingStream) SendMsg(any) error { return nil }
func (s *orderTrackingStream) RecvMsg(any) error { return nil }
func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 4
const ttl = 100 * time.Millisecond
const ackDelay = 200 * time.Millisecond
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
s.tokenTTL = ttl
// Build a stream that validates tokens immediately on Send, then
// delays the ack to ensure the next batch's tokens are generated fresh.
var validateErrs []error
ackCh := make(chan struct{}, 2)
stream := &tokenValidatingSyncStream{
tokenStore: s.tokenStore,
validateErrs: &validateErrs,
ackCh: ackCh,
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
go func() {
// Delay first ack so that if tokens were all generated upfront they'd expire.
time.Sleep(ackDelay)
ackCh <- struct{}{}
// Final batch ack — immediate.
ackCh <- struct{}{}
}()
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Empty(t, validateErrs,
"tokens must remain valid: per-batch generation guarantees freshness")
}
type tokenValidatingSyncStream struct {
grpc.ServerStream
tokenStore *OneTimeTokenStore
validateErrs *[]error
ackCh chan struct{}
}
func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error {
for _, mapping := range m.Mapping {
if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil {
*s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err))
}
}
return nil
}
func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) {
<-s.ackCh
return ackMsg(), nil
}
func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() }
func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil }
func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil }
func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {}
func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil }
func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil }
func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) {
stream := &syncRecordingStream{}
conn := &proxyConnection{
syncStream: stream,
}
resp := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"},
},
InitialSyncComplete: true,
}
err := conn.sendResponse(resp)
require.NoError(t, err)
require.Len(t, stream.sent, 1)
assert.Len(t, stream.sent[0].Mapping, 1)
assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id)
assert.True(t, stream.sent[0].InitialSyncComplete)
}
func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) {
stream := &recordingStream{}
conn := &proxyConnection{
stream: stream,
}
resp := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
{Id: "svc-2", AccountId: "acct-2"},
},
}
err := conn.sendResponse(resp)
require.NoError(t, err)
require.Len(t, stream.messages, 1)
assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id)
}

View File

@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
require.NoError(t, err)
return token
}
@@ -125,7 +125,6 @@ func TestValidateSession_UserAllowed(t *testing.T) {
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
@@ -146,7 +145,6 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
@@ -324,7 +322,7 @@ func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Conte
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testValidateSessionServiceManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}

View File

@@ -444,7 +444,7 @@ func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain stri
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testServiceManager) GetClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil
}

View File

@@ -1319,7 +1319,7 @@ func Test_NetworkRouters_Update(t *testing.T) {
},
},
{
name: "Update non-existing router returns not found",
name: "Update non-existing router creates it",
networkId: "testNetworkId",
routerId: "nonExistingRouterId",
requestBody: &api.NetworkRouterRequest{
@@ -1328,7 +1328,11 @@ func Test_NetworkRouters_Update(t *testing.T) {
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusNotFound,
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
t.Helper()
assert.Equal(t, "nonExistingRouterId", router.Id)
},
},
{
name: "Update router with both peer and peer_groups",

View File

@@ -17,7 +17,6 @@ import (
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version"
)
@@ -54,7 +53,6 @@ type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() types.Engine
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error)
}
// ConnManager peer connection manager that holds state for current active connections
@@ -225,12 +223,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
servicesAuthPassword int
servicesAuthPin int
servicesAuthOIDC int
// Private-service signals — track adoption of NetBird-only mode
// (services backed by an embedded proxy peer + access groups).
servicesPrivate int
servicesPrivateWithGroups int
servicesPrivateAccessGroupsSum int
servicesWithDirectUpstream int
)
start := time.Now()
metricsProperties := make(properties)
@@ -388,31 +380,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
servicesAuthOIDC++
}
if service.Private {
servicesPrivate++
if len(service.AccessGroups) > 0 {
servicesPrivateWithGroups++
}
servicesPrivateAccessGroupsSum += len(service.AccessGroups)
}
for _, target := range service.Targets {
if target.Options.DirectUpstream {
servicesWithDirectUpstream++
break
}
}
}
}
// Proxy / BYOP cluster signals come from the proxies table aggregated
// across all accounts in a single store query; nil on FileStore.
proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx)
if err != nil {
log.WithContext(ctx).Debugf("collect proxy metrics: %v", err)
}
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
metricsProperties["uptime"] = uptime
metricsProperties["accounts"] = accounts
@@ -460,15 +430,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["services_auth_password"] = servicesAuthPassword
metricsProperties["services_auth_pin"] = servicesAuthPin
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
metricsProperties["services_private"] = servicesPrivate
metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups
metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum
metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream
metricsProperties["proxy_clusters"] = proxyMetrics.Clusters
metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP
metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate
metricsProperties["proxies"] = proxyMetrics.Proxies
metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected
metricsProperties["custom_domains"] = customDomains
metricsProperties["custom_domains_validated"] = customDomainsValidated

View File

@@ -12,7 +12,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -124,7 +123,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: "peer"},
{TargetType: "host", Options: rpservice.TargetOptions{DirectUpstream: true}},
{TargetType: "host"},
},
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
@@ -142,16 +141,6 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
},
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
},
{
ID: "svc3-private",
Enabled: true,
Private: true,
AccessGroups: []string{"grp-eng", "grp-ops"},
Targets: []*rpservice.Target{
{TargetType: "cluster", Options: rpservice.TargetOptions{DirectUpstream: true}},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
},
},
},
{
@@ -265,18 +254,6 @@ func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, e
return 3, 2, nil
}
// GetProxyMetrics returns canned proxy/cluster counts so the
// generateProperties test can assert the BYOP signals end-to-end.
func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, error) {
return store.ProxyMetrics{
Clusters: 3,
ClustersBYOP: 1,
ClustersPrivate: 1,
Proxies: 4,
ProxiesConnected: 2,
}, nil
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
func TestGenerateProperties(t *testing.T) {
ds := mockDatasource{}
@@ -416,17 +393,17 @@ func TestGenerateProperties(t *testing.T) {
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
}
if properties["services"] != 3 {
t.Errorf("expected 3 services, got %v", properties["services"])
if properties["services"] != 2 {
t.Errorf("expected 2 services, got %v", properties["services"])
}
if properties["services_enabled"] != 2 {
t.Errorf("expected 2 services_enabled, got %v", properties["services_enabled"])
if properties["services_enabled"] != 1 {
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
}
if properties["services_targets"] != 4 {
t.Errorf("expected 4 services_targets, got %v", properties["services_targets"])
if properties["services_targets"] != 3 {
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
}
if properties["services_status_active"] != 2 {
t.Errorf("expected 2 services_status_active, got %v", properties["services_status_active"])
if properties["services_status_active"] != 1 {
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
}
if properties["services_status_pending"] != 1 {
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
@@ -443,9 +420,6 @@ func TestGenerateProperties(t *testing.T) {
if properties["services_target_type_domain"] != 1 {
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
}
if properties["services_target_type_cluster"] != 1 {
t.Errorf("expected 1 services_target_type_cluster, got %v", properties["services_target_type_cluster"])
}
if properties["services_auth_password"] != 1 {
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
}
@@ -455,33 +429,6 @@ func TestGenerateProperties(t *testing.T) {
if properties["services_auth_pin"] != 0 {
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
}
if properties["services_private"] != 1 {
t.Errorf("expected 1 services_private, got %v", properties["services_private"])
}
if properties["services_private_with_access_groups"] != 1 {
t.Errorf("expected 1 services_private_with_access_groups, got %v", properties["services_private_with_access_groups"])
}
if properties["services_private_access_groups_sum"] != 2 {
t.Errorf("expected 2 services_private_access_groups_sum, got %v", properties["services_private_access_groups_sum"])
}
if properties["services_with_direct_upstream"] != 2 {
t.Errorf("expected 2 services_with_direct_upstream, got %v", properties["services_with_direct_upstream"])
}
if properties["proxy_clusters"] != int64(3) {
t.Errorf("expected 3 proxy_clusters, got %v", properties["proxy_clusters"])
}
if properties["proxy_clusters_byop"] != int64(1) {
t.Errorf("expected 1 proxy_clusters_byop, got %v", properties["proxy_clusters_byop"])
}
if properties["proxy_clusters_private"] != int64(1) {
t.Errorf("expected 1 proxy_clusters_private, got %v", properties["proxy_clusters_private"])
}
if properties["proxies"] != int64(4) {
t.Errorf("expected 4 proxies, got %v", properties["proxies"])
}
if properties["proxies_connected"] != int64(2) {
t.Errorf("expected 2 proxies_connected, got %v", properties["proxies_connected"])
}
if properties["custom_domains"] != int64(3) {
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
}

View File

@@ -198,11 +198,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
require.NoError(t, err, "Failed to insert account")
account.PeersG = []nbpeer.Peer{
{
AccountID: "1234",
Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}},
Status: &nbpeer.PeerStatus{LastSeen: time.Now()},
},
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(account).Error

View File

@@ -34,11 +34,8 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
require.NoError(t, err)
ids := make([]string, 0, len(networks))
for _, n := range networks {
ids = append(ids, n.ID)
}
require.ElementsMatch(t, []string{"testNetworkId", "secondNetworkId"}, ids)
require.Len(t, networks, 1)
require.Equal(t, "testNetworkId", networks[0].ID)
}
func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {

View File

@@ -102,7 +102,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
router.ID = xid.New().String()
err = transaction.CreateNetworkRouter(ctx, router)
err = transaction.SaveNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to create network router: %w", err)
}
@@ -162,20 +162,11 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return fmt.Errorf("failed to get network: %w", err)
}
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
if err != nil {
return fmt.Errorf("failed to get network router: %w", err)
}
if existing.AccountID != router.AccountID {
return status.NewNetworkRouterNotFoundError(router.ID)
}
if existing.NetworkID != router.NetworkID {
if network.ID != router.NetworkID {
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
err = transaction.UpdateNetworkRouter(ctx, router)
err = transaction.SaveNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)
}

View File

@@ -195,7 +195,6 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
if err != nil {
require.NoError(t, err)
}
router.ID = "testRouterId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
@@ -211,102 +210,6 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
require.Equal(t, router.Metric, updatedRouter.Metric)
}
func Test_UpdateRouterRejectsCrossAccountID(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
// Admin of testAccountId tries to update a router that belongs to otherAccountId
// by passing the other account's router ID through the URL.
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true)
if err != nil {
require.NoError(t, err)
}
router.ID = "otherRouterId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
require.Error(t, err)
require.Nil(t, updatedRouter)
// The other account's router must be untouched.
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "otherAccountId", "otherRouterId")
require.NoError(t, err)
require.Equal(t, "otherAccountId", stored.AccountID)
require.Equal(t, "otherNetworkId", stored.NetworkID)
require.Equal(t, "otherPeer", stored.Peer)
require.Equal(t, 1, stored.Metric)
}
func Test_CreateRouterRejectsCrossAccountID(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
// Admin of testAccountId tries to create a router in otherAccountId's network.
// The permission check is on router.AccountID (their own), but the network
// lookup must fail because (testAccountId, otherNetworkId) does not exist.
router, err := types.NewNetworkRouter("testAccountId", "otherNetworkId", "testPeerId", []string{}, false, 1, true)
if err != nil {
require.NoError(t, err)
}
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
createdRouter, err := manager.CreateRouter(ctx, userID, router)
require.Error(t, err)
require.Nil(t, createdRouter)
// No router should have been created in either account's scope under otherNetworkId.
routersInOther, err := s.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, "otherAccountId", "otherNetworkId")
require.NoError(t, err)
require.Len(t, routersInOther, 1)
require.Equal(t, "otherRouterId", routersInOther[0].ID)
}
func Test_UpdateRouterRejectsNetworkMismatch(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
// The router exists in testNetworkId, but the caller submits secondNetworkId
// (a different network in the same account). The update must be refused.
router, err := types.NewNetworkRouter("testAccountId", "secondNetworkId", "testPeerId", []string{}, false, 1, true)
if err != nil {
require.NoError(t, err)
}
router.ID = "testRouterId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
require.Error(t, err)
require.Nil(t, updatedRouter)
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "testAccountId", "testRouterId")
require.NoError(t, err)
require.Equal(t, "testNetworkId", stored.NetworkID)
}
func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
userID := "testUserId"

View File

@@ -30,7 +30,6 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/version"
)
const remoteJobsMinVer = "0.64.0"
@@ -126,18 +125,6 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
}
// An embedded proxy peer flipping to connected is the trigger for
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
// tunnel IP. Without an account-wide netmap recompute, user peers keep
// the stale synth (or no synth at all on first connect) until some
// other change pokes the controller. Fire OnPeersUpdated so the
// buffered recompute fans the new state out to every peer.
if peer.ProxyMeta.Embedded {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
}
}
return nil
}
@@ -173,17 +160,6 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
return nil
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
// offline, drive an account-wide netmap recompute so the synthesized
// DNS records that pointed at it are pulled. Without this the records
// linger client-side at TTL until something else triggers a refresh.
if peer.ProxyMeta.Embedded {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
}
}
return nil
}
@@ -373,7 +349,7 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p
}
meetMinVer, err := posture.MeetsMinVersion(remoteJobsMinVer, p.Meta.WtVersion)
if !version.IsDevelopmentVersion(p.Meta.WtVersion) && (!meetMinVer || err != nil) {
if !strings.Contains(p.Meta.WtVersion, "dev") && (!meetMinVer || err != nil) {
return status.Errorf(status.PreconditionFailed, "peer version %s does not meet the minimum required version %s for remote jobs", p.Meta.WtVersion, remoteJobsMinVer)
}

View File

@@ -86,7 +86,7 @@ type PeerStatus struct { //nolint:revive
// active session". Integer nanoseconds are used so equality is
// precision-safe across drivers, and so the predicates compose to a
// single bigint comparison.
SessionStartedAt int64 `gorm:"not null;default:0"`
SessionStartedAt int64
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired

View File

@@ -2218,9 +2218,6 @@ func Test_IsUniqueConstraintError(t *testing.T) {
ID: "test-peer-id",
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
DNSLabel: "test-peer-dns-label",
Status: &nbpeer.PeerStatus{
LastSeen: time.Now(),
},
}
for _, tt := range tests {

View File

@@ -274,9 +274,3 @@ func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
return 0, 0, nil
}
// GetProxyMetrics is a no-op for FileStore — proxy/cluster state isn't
// persisted in the JSON file format.
func (s *FileStore) GetProxyMetrics(_ context.Context) (ProxyMetrics, error) {
return ProxyMetrics{}, nil
}

View File

@@ -1090,38 +1090,6 @@ func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, er
return total, validated, nil
}
// GetProxyMetrics aggregates per-cluster + per-proxy counts for the
// self-hosted telemetry payload. Single round-trip via conditional
// aggregations so a large proxies table doesn't fan out into multiple
// queries.
func (s *SqlStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) {
var m ProxyMetrics
activeCutoff := time.Now().Add(-proxyActiveThreshold)
// COUNT(DISTINCT ... CASE WHEN ...) is portable across sqlite/postgres
// (MySQL too) and keeps the round-trip to one. proxy.StatusConnected
// is the same string the cluster-capability queries use; the active
// window matches the cluster-capability semantics (only proxies
// heartbeating within ~2 * heartbeat interval count as connected).
row := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Select(
"COUNT(DISTINCT cluster_address) AS clusters, "+
"COUNT(DISTINCT CASE WHEN account_id IS NOT NULL THEN cluster_address END) AS clusters_byop, "+
"COUNT(DISTINCT CASE WHEN private = ? THEN cluster_address END) AS clusters_private, "+
"COUNT(*) AS proxies, "+
"COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS proxies_connected",
true,
proxy.StatusConnected,
activeCutoff,
).
Row()
if err := row.Scan(&m.Clusters, &m.ClustersBYOP, &m.ClustersPrivate, &m.Proxies, &m.ProxiesConnected); err != nil {
return ProxyMetrics{}, fmt.Errorf("scan proxy metrics: %w", err)
}
return m, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
var accounts []types.Account
result := s.db.Find(&accounts)
@@ -2210,8 +2178,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
mode, listen_port, port_auto_assigned, source, source_peer, terminated,
private, access_groups
mode, listen_port, port_auto_assigned, source, source_peer, terminated
FROM services WHERE account_id = $1`
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
@@ -2226,11 +2193,10 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
var s rpservice.Service
var auth []byte
var accessGroups []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString
var terminated, portAutoAssigned, private sql.NullBool
var terminated, portAutoAssigned sql.NullBool
var listenPort sql.NullInt64
err := row.Scan(
&s.ID,
@@ -2253,8 +2219,6 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&source,
&sourcePeer,
&terminated,
&private,
&accessGroups,
)
if err != nil {
return nil, err
@@ -2266,16 +2230,6 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
}
}
if len(accessGroups) > 0 {
if err := json.Unmarshal(accessGroups, &s.AccessGroups); err != nil {
return nil, fmt.Errorf("unmarshal access_groups: %w", err)
}
}
if private.Valid {
s.Private = private.Bool
}
s.Meta = rpservice.Meta{}
if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time
@@ -3509,6 +3463,49 @@ func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength Lockin
return allEphemeralPeers, nil
}
// GetStaleEphemeralPeerIDsForAccount returns IDs of disconnected
// ephemeral peers in the given account whose last_seen is strictly
// older than olderThan.
func (s *SqlStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
var ids []string
err := s.db.WithContext(ctx).
Model(&nbpeer.Peer{}).
Where("account_id = ? AND ephemeral = ? AND peer_status_connected = ? AND peer_status_last_seen < ?",
accountID, true, false, olderThan).
Pluck("id", &ids).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to query stale ephemeral peers for account %s: %v", accountID, err)
return nil, status.Errorf(status.Internal, "query stale ephemeral peers")
}
return ids, nil
}
// GetEphemeralAccountsLastDisconnect returns the latest peer_status_last_seen
// per account across disconnected ephemeral peers. Returns one entry per
// account that has at least one such peer.
func (s *SqlStore) GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error) {
type row struct {
AccountID string
LastSeen time.Time
}
var rows []row
err := s.db.WithContext(ctx).
Model(&nbpeer.Peer{}).
Select("account_id, MAX(peer_status_last_seen) AS last_seen").
Where("ephemeral = ? AND peer_status_connected = ?", true, false).
Group("account_id").
Scan(&rows).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to load ephemeral-account last disconnect map: %v", err)
return nil, status.Errorf(status.Internal, "load ephemeral accounts")
}
out := make(map[string]time.Time, len(rows))
for _, r := range rows {
out[r.AccountID] = r.LastSeen
}
return out, nil
}
// DeletePeer removes a peer from the store.
func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID string) error {
result := s.db.Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
@@ -4361,27 +4358,11 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin
return netRouter, nil
}
func (s *SqlStore) CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
if err := s.db.Create(router).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create network router in store: %v", err)
return status.Errorf(status.Internal, "failed to create network router in store")
}
return nil
}
func (s *SqlStore) UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
result := s.db.
Select("*").
Where(accountAndIDQueryCondition, router.AccountID, router.ID).
Updates(router)
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
result := s.db.Save(router)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update network router in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to update network router in store")
}
if result.RowsAffected == 0 {
return status.NewNetworkRouterNotFoundError(router.ID)
log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save network router to store")
}
return nil
@@ -5798,67 +5779,19 @@ func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, acc
return nil
}
// GetProxyClusters returns every cluster the account can see (shared
// plus its own BYOP), regardless of whether any proxy in the cluster
// is currently heartbeating. Online and ConnectedProxies are derived
// from the 2-min active window so the dashboard can render offline
// clusters distinctly; the 1-hour heartbeat reaper still removes rows
// that go quiet for too long.
//
// AccountOwned is determined by whether any proxy row in the group
// carries a non-NULL account_id; the caller maps that to Cluster.Type.
// Capability flags are NOT filled here — the handler enriches them via
// the per-cluster capability lookups.
func (s *SqlStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
activeCutoff := time.Now().Add(-proxyActiveThreshold)
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
var clusters []proxy.Cluster
type clusterRow struct {
ID string
Address string
ConnectedProxies int
Online bool
AccountOwned bool
}
var rows []clusterRow
result := s.db.Model(&proxy.Proxy{}).
Select(
"MIN(id) AS id, "+
"cluster_address AS address, "+
// COUNT(CASE WHEN ... THEN 1 END) counts only non-NULL — i.e. only
// rows that satisfy the predicate — so it works portably across
// sqlite/postgres/mysql without dialect-specific FILTER syntax.
"COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS connected_proxies, "+
// MAX(CASE …) > 0 expresses BOOL_OR in a way Postgres tolerates
// (Postgres can't MAX a boolean column).
"MAX(CASE WHEN status = ? AND last_seen > ? THEN 1 ELSE 0 END) > 0 AS online, "+
"MAX(CASE WHEN account_id IS NOT NULL THEN 1 ELSE 0 END) > 0 AS account_owned",
proxy.StatusConnected, activeCutoff,
proxy.StatusConnected, activeCutoff,
).
Where("account_id IS NULL OR account_id = ?", accountID).
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted").
Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)",
proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID).
Group("cluster_address").
Scan(&rows)
Scan(&clusters)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get proxy clusters: %v", result.Error)
return nil, status.Errorf(status.Internal, "get proxy clusters")
}
clusters := make([]proxy.Cluster, 0, len(rows))
for _, r := range rows {
c := proxy.Cluster{
ID: r.ID,
Address: r.Address,
Online: r.Online,
ConnectedProxies: r.ConnectedProxies,
}
if r.AccountOwned {
c.Type = proxy.ClusterTypeAccount
} else {
c.Type = proxy.ClusterTypeShared
}
clusters = append(clusters, c)
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error)
return nil, status.Errorf(status.Internal, "get active proxy clusters")
}
return clusters, nil
@@ -5872,7 +5805,6 @@ var validCapabilityColumns = map[string]struct{}{
"supports_custom_ports": {},
"require_subdomain": {},
"supports_crowdsec": {},
"private": {},
}
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
@@ -5887,12 +5819,6 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
}
// GetClusterSupportsPrivate reports whether any active proxy in the cluster
// has the private capability (nil = unreported).
func (s *SqlStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
return s.getClusterCapability(ctx, clusterAddr, "private")
}
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
// have CrowdSec configured. Returns nil when no proxy reported the capability.
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec

View File

@@ -1,109 +0,0 @@
package store
import (
"context"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// TestSqlStore_GetProxyClusters_DerivesOnlineAndType guards the
// account-visible cluster list against silent regressions in two
// dimensions:
//
// 1. Online derivation: a cluster with one stale and one fresh proxy
// is online and counts only the fresh proxy; a cluster whose
// proxies all heartbeated outside the 2-min window appears offline
// with connected_proxies = 0 (rather than disappearing, which is
// what the old query did).
// 2. Type derivation: a cluster scoped to the calling account is
// reported as `account`; a cluster with account_id IS NULL is
// reported as `shared`. Clusters scoped to other accounts must not
// leak into the result.
//
// Capability flags are intentionally not asserted here — they're filled
// by the manager (handler) layer from the per-cluster capability
// lookups, not by the store query.
func TestSqlStore_GetProxyClusters_DerivesOnlineAndType(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
accountID := "acct-clusters"
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
otherAccountID := "acct-other"
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, otherAccountID, "user-2", "")))
acctID := accountID
otherID := otherAccountID
fresh := time.Now().Add(-30 * time.Second)
stale := time.Now().Add(-30 * time.Minute)
mustSave := func(id, cluster string, accID *string, status string, lastSeen time.Time) {
require.NoError(t, store.SaveProxy(ctx, &rpproxy.Proxy{
ID: id,
SessionID: id + "-sess",
ClusterAddress: cluster,
IPAddress: "10.0.0.1",
AccountID: accID,
LastSeen: lastSeen,
Status: status,
}))
}
// shared-mixed: one fresh + one stale proxy → online, connected=1
mustSave("p-shared-fresh", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, fresh)
mustSave("p-shared-stale", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, stale)
// shared-offline: only stale proxies → offline, connected=0,
// but row must still appear (this is the new semantic — old
// query would have dropped it entirely).
mustSave("p-shared-off", "shared-offline.netbird.io", nil, rpproxy.StatusConnected, stale)
// account-online: BYOP cluster owned by acctID, fresh
mustSave("p-acct-fresh", "byop.acct.example", &acctID, rpproxy.StatusConnected, fresh)
// other-account: must not surface for acctID
mustSave("p-other", "byop.other.example", &otherID, rpproxy.StatusConnected, fresh)
clusters, err := store.GetProxyClusters(ctx, accountID)
require.NoError(t, err)
byAddr := map[string]rpproxy.Cluster{}
for _, c := range clusters {
byAddr[c.Address] = c
}
assert.NotContains(t, byAddr, "byop.other.example",
"another account's BYOP cluster must not leak into this account's listing")
require.Contains(t, byAddr, "shared-mixed.netbird.io")
mixed := byAddr["shared-mixed.netbird.io"]
assert.Equal(t, rpproxy.ClusterTypeShared, mixed.Type, "shared cluster (account_id IS NULL) must be reported as Type=shared")
assert.True(t, mixed.Online, "cluster with a fresh proxy must be online")
assert.Equal(t, 1, mixed.ConnectedProxies, "connected_proxies must count only fresh proxies; the stale one should not bump the count")
require.Contains(t, byAddr, "shared-offline.netbird.io",
"offline clusters must still appear so the dashboard can render them — the old GetActiveProxyClusters would have dropped this row, which is the regression this test guards against")
offline := byAddr["shared-offline.netbird.io"]
assert.Equal(t, rpproxy.ClusterTypeShared, offline.Type)
assert.False(t, offline.Online, "no fresh heartbeat → offline")
assert.Equal(t, 0, offline.ConnectedProxies, "no fresh proxies → connected_proxies=0")
require.Contains(t, byAddr, "byop.acct.example")
acct := byAddr["byop.acct.example"]
assert.Equal(t, rpproxy.ClusterTypeAccount, acct.Type, "BYOP cluster owned by the account must be reported as Type=account")
assert.True(t, acct.Online)
assert.Equal(t, 1, acct.ConnectedProxies)
})
}

View File

@@ -1,46 +0,0 @@
package store
import (
"context"
"os"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
)
func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
account := newAccountWithId(ctx, "account_private_svc", "testuser", "")
require.NoError(t, store.SaveAccount(ctx, account))
svc := &rpservice.Service{
ID: "svc-private",
AccountID: account.Id,
Name: "private-svc",
Domain: "private.example",
ProxyCluster: "cluster.example",
Enabled: true,
Mode: rpservice.ModeHTTP,
Private: true,
AccessGroups: []string{"grp-admins", "grp-ops"},
}
require.NoError(t, store.CreateService(ctx, svc))
loaded, err := store.GetAccount(ctx, account.Id)
require.NoError(t, err)
require.Len(t, loaded.Services, 1)
got := loaded.Services[0]
assert.True(t, got.Private)
assert.Equal(t, []string{"grp-admins", "grp-ops"}, got.AccessGroups)
})
}

View File

@@ -2399,7 +2399,7 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) {
}
}
func TestSqlStore_CreateNetworkRouter(t *testing.T) {
func TestSqlStore_SaveNetworkRouter(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@@ -2410,7 +2410,7 @@ func TestSqlStore_CreateNetworkRouter(t *testing.T) {
netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0, true)
require.NoError(t, err)
err = store.CreateNetworkRouter(context.Background(), netRouter)
err = store.SaveNetworkRouter(context.Background(), netRouter)
require.NoError(t, err)
savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, netRouter.ID)
@@ -2418,39 +2418,6 @@ func TestSqlStore_CreateNetworkRouter(t *testing.T) {
require.Equal(t, netRouter, savedNetRouter)
}
func TestSqlStore_UpdateNetworkRouter(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
networkID := "ct286bi7qv930dsrrug0"
routerID := "ctc20ji7qv9ck2sebc80"
netRouter := &routerTypes.NetworkRouter{
ID: routerID,
AccountID: accountID,
NetworkID: networkID,
Peer: "",
PeerGroups: []string{"net-router-grp"},
Masquerade: true,
Metric: 42,
Enabled: true,
}
err = store.UpdateNetworkRouter(context.Background(), netRouter)
require.NoError(t, err)
savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, routerID)
require.NoError(t, err)
require.Equal(t, netRouter, savedNetRouter)
// Updating a router under a different account must not match any row.
netRouter.AccountID = "non-existent-account"
err = store.UpdateNetworkRouter(context.Background(), netRouter)
require.Error(t, err)
}
func TestSqlStore_DeleteNetworkRouter(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)

View File

@@ -165,6 +165,15 @@ type Store interface {
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
// GetStaleEphemeralPeerIDsForAccount returns the IDs of disconnected
// ephemeral peers whose last_seen is strictly older than olderThan,
// scoped to a single account. Used by the per-account cleanup sweep.
GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error)
// GetEphemeralAccountsLastDisconnect returns, for every account that
// has at least one disconnected ephemeral peer, the most recent
// last_seen across that account's disconnected ephemeral peers. Used
// to reconstruct the per-account cleanup tracker after a restart.
GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
@@ -228,8 +237,7 @@ type Store interface {
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
@@ -308,11 +316,10 @@ type Store interface {
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
@@ -321,38 +328,9 @@ type Store interface {
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
// GetProxyMetrics returns aggregated proxy / cluster counts for the
// self-hosted metrics worker. Self-hosted only — file-based stores
// return a zero-valued struct.
GetProxyMetrics(ctx context.Context) (ProxyMetrics, error)
GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error)
}
// ProxyMetrics aggregates self-hosted proxy + cluster usage signals
// surfaced to the telemetry payload. Each field is best-effort: when a
// store cannot answer (e.g. FileStore) all fields are zero.
type ProxyMetrics struct {
// Clusters counts distinct cluster_address values across the proxies
// table — every cluster the management server has heard from, online or not.
Clusters int64
// ClustersBYOP counts distinct cluster_address values that are owned
// by an account (account_id IS NOT NULL). These are bring-your-own-proxy
// installations as opposed to NetBird-operated shared clusters.
ClustersBYOP int64
// ClustersPrivate counts distinct cluster_address values where at
// least one proxy reported the private capability (embedded
// `netbird proxy` running inside a client).
ClustersPrivate int64
// Proxies is the total number of proxy rows currently persisted.
Proxies int64
// ProxiesConnected is the subset of proxies whose status is
// "connected" AND last_seen falls within the active heartbeat window
// (~2 * heartbeat interval). Proxies the controller hasn't pruned
// yet but that are visibly stale don't count.
ProxiesConnected int64
}
const (
postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN"
postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
@@ -502,9 +480,6 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.MigrateNewField[types.User](ctx, db, "email", "")
},
func(db *gorm.DB) error {
return migration.MigrateNewField[nbpeer.Peer](ctx, db, "peer_status_session_started_at", int64(0))
},
func(db *gorm.DB) error {
return migration.RemoveDuplicatePeerKeys(ctx, db)
},

View File

@@ -310,20 +310,6 @@ func (mr *MockStoreMockRecorder) CreateGroups(ctx, accountID, groups interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroups", reflect.TypeOf((*MockStore)(nil).CreateGroups), ctx, accountID, groups)
}
// CreateNetworkRouter mocks base method.
func (m *MockStore) CreateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateNetworkRouter", ctx, router)
ret0, _ := ret[0].(error)
return ret0
}
// CreateNetworkRouter indicates an expected call of CreateNetworkRouter.
func (mr *MockStoreMockRecorder) CreateNetworkRouter(ctx, router interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNetworkRouter", reflect.TypeOf((*MockStore)(nil).CreateNetworkRouter), ctx, router)
}
// CreatePeerJob mocks base method.
func (m *MockStore) CreatePeerJob(ctx context.Context, job *types2.Job) error {
m.ctrl.T.Helper()
@@ -394,20 +380,6 @@ func (mr *MockStoreMockRecorder) DeleteAccount(ctx, account interface{}) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccount", reflect.TypeOf((*MockStore)(nil).DeleteAccount), ctx, account)
}
// DeleteAccountCluster mocks base method.
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
}
// DeleteCustomDomain mocks base method.
func (m *MockStore) DeleteCustomDomain(ctx context.Context, accountID, domainID string) error {
m.ctrl.T.Helper()
@@ -605,6 +577,20 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
}
// DeleteAccountCluster mocks base method.
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
}
// DeleteRoute mocks base method.
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
m.ctrl.T.Helper()
@@ -745,20 +731,6 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
}
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
ret0, _ := ret[0].(error)
return ret0
}
// DisconnectProxy indicates an expected call of DisconnectProxy.
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
}
// EphemeralServiceExists mocks base method.
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
m.ctrl.T.Helper()
@@ -1360,6 +1332,21 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
}
// GetActiveProxyClusters mocks base method.
func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID)
}
// GetAllAccounts mocks base method.
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
m.ctrl.T.Helper()
@@ -1389,6 +1376,36 @@ func (mr *MockStoreMockRecorder) GetAllEphemeralPeers(ctx, lockStrength interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllEphemeralPeers", reflect.TypeOf((*MockStore)(nil).GetAllEphemeralPeers), ctx, lockStrength)
}
// GetStaleEphemeralPeerIDsForAccount mocks base method.
func (m *MockStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStaleEphemeralPeerIDsForAccount", ctx, accountID, olderThan)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetStaleEphemeralPeerIDsForAccount indicates an expected call of GetStaleEphemeralPeerIDsForAccount.
func (mr *MockStoreMockRecorder) GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaleEphemeralPeerIDsForAccount", reflect.TypeOf((*MockStore)(nil).GetStaleEphemeralPeerIDsForAccount), ctx, accountID, olderThan)
}
// GetEphemeralAccountsLastDisconnect mocks base method.
func (m *MockStore) GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEphemeralAccountsLastDisconnect", ctx)
ret0, _ := ret[0].(map[string]time.Time)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEphemeralAccountsLastDisconnect indicates an expected call of GetEphemeralAccountsLastDisconnect.
func (mr *MockStoreMockRecorder) GetEphemeralAccountsLastDisconnect(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEphemeralAccountsLastDisconnect", reflect.TypeOf((*MockStore)(nil).GetEphemeralAccountsLastDisconnect), ctx)
}
// GetAllProxyAccessTokens mocks base method.
func (m *MockStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
@@ -1461,20 +1478,6 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
}
// GetClusterSupportsPrivate mocks base method.
func (m *MockStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsPrivate", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsPrivate indicates an expected call of GetClusterSupportsPrivate.
func (mr *MockStoreMockRecorder) GetClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsPrivate", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsPrivate), ctx, clusterAddr)
}
// GetCustomDomain mocks base method.
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
m.ctrl.T.Helper()
@@ -2075,36 +2078,6 @@ func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
}
// GetProxyClusters mocks base method.
func (m *MockStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyClusters", ctx, accountID)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyClusters indicates an expected call of GetProxyClusters.
func (mr *MockStoreMockRecorder) GetProxyClusters(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyClusters", reflect.TypeOf((*MockStore)(nil).GetProxyClusters), ctx, accountID)
}
// GetProxyMetrics mocks base method.
func (m *MockStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyMetrics", ctx)
ret0, _ := ret[0].(ProxyMetrics)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyMetrics indicates an expected call of GetProxyMetrics.
func (mr *MockStoreMockRecorder) GetProxyMetrics(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyMetrics", reflect.TypeOf((*MockStore)(nil).GetProxyMetrics), ctx)
}
// GetResourceGroups mocks base method.
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
m.ctrl.T.Helper()
@@ -2655,36 +2628,6 @@ func (mr *MockStoreMockRecorder) MarkPATUsed(ctx, patID interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPATUsed", reflect.TypeOf((*MockStore)(nil).MarkPATUsed), ctx, patID)
}
// MarkPeerConnectedIfNewerSession mocks base method.
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
}
// MarkPeerDisconnectedIfSameSession mocks base method.
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
}
// MarkPendingJobsAsFailed mocks base method.
func (m *MockStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error {
m.ctrl.T.Helper()
@@ -2895,6 +2838,20 @@ func (mr *MockStoreMockRecorder) SaveNetworkResource(ctx, resource interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkResource", reflect.TypeOf((*MockStore)(nil).SaveNetworkResource), ctx, resource)
}
// SaveNetworkRouter mocks base method.
func (m *MockStore) SaveNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveNetworkRouter", ctx, router)
ret0, _ := ret[0].(error)
return ret0
}
// SaveNetworkRouter indicates an expected call of SaveNetworkRouter.
func (mr *MockStoreMockRecorder) SaveNetworkRouter(ctx, router interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkRouter", reflect.TypeOf((*MockStore)(nil).SaveNetworkRouter), ctx, router)
}
// SavePAT mocks base method.
func (m *MockStore) SavePAT(ctx context.Context, pat *types2.PersonalAccessToken) error {
m.ctrl.T.Helper()
@@ -2951,6 +2908,36 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status)
}
// MarkPeerConnectedIfNewerSession mocks base method.
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
}
// MarkPeerDisconnectedIfSameSession mocks base method.
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
}
// SavePolicy mocks base method.
func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error {
m.ctrl.T.Helper()
@@ -2993,6 +2980,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
ret0, _ := ret[0].(error)
return ret0
}
// DisconnectProxy indicates an expected call of DisconnectProxy.
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
}
// SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper()
@@ -3202,20 +3203,6 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
}
// UpdateNetworkRouter mocks base method.
func (m *MockStore) UpdateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateNetworkRouter", ctx, router)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateNetworkRouter indicates an expected call of UpdateNetworkRouter.
func (mr *MockStoreMockRecorder) UpdateNetworkRouter(ctx, router interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateNetworkRouter", reflect.TypeOf((*MockStore)(nil).UpdateNetworkRouter), ctx, router)
}
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
m.ctrl.T.Helper()

View File

@@ -7,9 +7,9 @@ import (
)
// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how
// many peers are currently scheduled for deletion, how many tick runs
// the cleaner has performed, how many peers it has removed, and how
// many delete batches failed.
// many accounts are currently being tracked for cleanup, how many sweep
// runs deleted at least one peer, how many peers have been removed, and
// how many delete batches failed.
type EphemeralPeersMetrics struct {
ctx context.Context
@@ -21,16 +21,16 @@ type EphemeralPeersMetrics struct {
// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters.
func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) {
pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending",
pending, err := meter.Int64UpDownCounter("management.ephemeral.accounts.tracked",
metric.WithUnit("1"),
metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up"))
metric.WithDescription("Number of accounts currently tracked for ephemeral peer cleanup"))
if err != nil {
return nil, err
}
cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer"))
metric.WithDescription("Number of ephemeral cleanup sweeps that deleted at least one peer"))
if err != nil {
return nil, err
}
@@ -61,7 +61,8 @@ func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*Ephemer
// All methods are nil-receiver safe so callers that haven't wired metrics
// (tests, self-hosted with metrics off) can invoke them unconditionally.
// IncPending bumps the pending gauge when a peer is added to the cleanup list.
// IncPending bumps the tracked-accounts gauge when a new account
// becomes eligible for ephemeral cleanup tracking.
func (m *EphemeralPeersMetrics) IncPending() {
if m == nil {
return
@@ -69,8 +70,8 @@ func (m *EphemeralPeersMetrics) IncPending() {
m.pending.Add(m.ctx, 1)
}
// AddPending bumps the pending gauge by n — used at startup when the
// initial set of ephemeral peers is loaded from the store.
// AddPending bumps the tracked-accounts gauge by n — used at startup
// when the catch-up query seeds the tracker.
func (m *EphemeralPeersMetrics) AddPending(n int64) {
if m == nil || n <= 0 {
return
@@ -78,9 +79,8 @@ func (m *EphemeralPeersMetrics) AddPending(n int64) {
m.pending.Add(m.ctx, n)
}
// DecPending decreases the pending gauge — used both when a peer reconnects
// before its deadline (removed from the list) and when a cleanup tick
// actually deletes it.
// DecPending decreases the tracked-accounts gauge when an account is
// dropped from the tracker (no more disconnects to chase).
func (m *EphemeralPeersMetrics) DecPending(n int64) {
if m == nil || n <= 0 {
return

View File

@@ -9,13 +9,9 @@ INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM
CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO networks VALUES('testNetworkId','testAccountId','some-name','some-description');
INSERT INTO networks VALUES('secondNetworkId','testAccountId','second-name','second-description');
CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','','["csquuo4jcko732k1ag00"]',0,9999);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.65.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO networks VALUES('otherNetworkId','otherAccountId','other-net','other-description');
INSERT INTO network_routers VALUES('otherRouterId','otherNetworkId','otherAccountId','otherPeer',NULL,0,1);
CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32');

View File

@@ -29,13 +29,10 @@ import (
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/version"
)
const (
defaultTTL = 300
// privateServiceDNSRecordTTL is short so proxy-peer changes propagate quickly to clients.
privateServiceDNSRecordTTL = 5
defaultTTL = 300
DefaultPeerLoginExpiration = 24 * time.Hour
DefaultPeerInactivityExpiration = 10 * time.Minute
@@ -257,117 +254,6 @@ func getUniqueHostLabel(name string, peerLabels LookupMap) string {
return ""
}
// SynthesizePrivateServiceZones returns in-memory CustomZones with A records pointing each enabled private service the peer can reach at the cluster's proxy-peer IPs. One zone per cluster (multiple services share); records gated by AccessGroups.
func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZone {
peer, ok := a.Peers[peerID]
if !ok || peer == nil {
return nil
}
if len(a.Services) == 0 {
return nil
}
proxyPeersByCluster := a.GetProxyPeers()
if len(proxyPeersByCluster) == 0 {
return nil
}
peerGroups := a.GetPeerGroups(peerID)
zonesByCluster := map[string]*nbdns.CustomZone{}
for _, svc := range a.Services {
if svc == nil || !svc.Enabled || !svc.Private {
continue
}
if len(svc.AccessGroups) == 0 {
continue
}
if !peerInDistributionGroups(peerGroups, svc.AccessGroups) {
continue
}
proxyPeers := proxyPeersByCluster[svc.ProxyCluster]
if len(proxyPeers) == 0 {
continue
}
zone, exists := zonesByCluster[svc.ProxyCluster]
if !exists {
// NonAuthoritative makes this a match-only zone: queries for
// names without an explicit record fall through to the
// upstream resolver instead of returning NXDOMAIN. Without
// it, adding a single private service would black-hole every
// other name under the cluster apex.
zone = &nbdns.CustomZone{
Domain: dns.Fqdn(svc.ProxyCluster),
Records: []nbdns.SimpleRecord{},
NonAuthoritative: true,
}
zonesByCluster[svc.ProxyCluster] = zone
}
emitted := 0
skippedDisconnected := 0
for _, p := range proxyPeers {
if p == nil || !p.IP.IsValid() {
continue
}
// Only emit a record when the proxy peer is actually
// connected. A disconnected proxy peer's tunnel IP won't
// answer; pointing DNS at it would produce a black hole
// for as long as the record is cached client-side.
if p.Status == nil || !p.Status.Connected {
skippedDisconnected++
continue
}
zone.Records = append(zone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(svc.Domain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: privateServiceDNSRecordTTL,
RData: p.IP.String(),
})
emitted++
}
// Disagreement with the firewall path is the typical
// "domain doesn't reach client but firewall rules do"
// symptom: the synth service is otherwise fine, only the
// proxy peer's persisted Connected flag is wrong (most
// likely the connection reaper marked it disconnected even
// though the gRPC stream is alive).
if emitted == 0 && skippedDisconnected > 0 {
log.Debugf("private-zone synth: svc %s domain=%s cluster=%s emitted_zero proxy_peers=%d all_disconnected=%d (firewall would still fire)",
svc.ID, svc.Domain, svc.ProxyCluster, len(proxyPeers), skippedDisconnected)
}
}
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
for _, zone := range zonesByCluster {
if len(zone.Records) == 0 {
continue
}
out = append(out, *zone)
}
if len(out) == 0 && len(a.Services) > 0 {
// Targeted diagnostic for the "firewall yes, DNS no" divergence —
// fires only when services exist but synth returns zero zones,
// so accounts without private services produce no noise.
log.Debugf("private-zone synth: peer %s account %s returned 0 zones from %d candidate service(s)",
peerID, a.Id, len(a.Services))
}
return out
}
// peerInDistributionGroups reports whether any of the peer's groups
// matches the service's bearer-auth distribution_groups.
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
for _, gid := range distributionGroups {
if _, ok := peerGroups[gid]; ok {
return true
}
}
return false
}
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
var merr *multierror.Error
@@ -1612,53 +1498,6 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *servi
a.injectTargetProxyPolicies(ctx, service, target, proxyPeers)
}
a.injectPrivateServicePolicies(service, proxyPeers)
}
// injectPrivateServicePolicies synthesises an in-memory ACL: AccessGroups → cluster proxy peers on TCP 80/443.
func (a *Account) injectPrivateServicePolicies(svc *service.Service, proxyPeers []*nbpeer.Peer) {
if !svc.Private {
return
}
if len(svc.AccessGroups) == 0 {
return
}
if len(proxyPeers) == 0 {
return
}
for _, proxyPeer := range proxyPeers {
a.Policies = append(a.Policies, a.createPrivateServicePolicy(svc, proxyPeer))
}
}
func (a *Account) createPrivateServicePolicy(svc *service.Service, proxyPeer *nbpeer.Peer) *Policy {
policyID := fmt.Sprintf("private-access-%s-%s", svc.ID, proxyPeer.ID)
sources := append([]string(nil), svc.AccessGroups...)
return &Policy{
ID: policyID,
Name: fmt.Sprintf("Private Access to %s", svc.Name),
Enabled: true,
Rules: []*PolicyRule{
{
ID: policyID,
PolicyID: policyID,
Name: fmt.Sprintf("Allow access groups to reach %s", svc.Name),
Enabled: true,
Sources: sources,
DestinationResource: Resource{
ID: proxyPeer.ID,
Type: ResourceTypePeer,
},
Bidirectional: false,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
PortRanges: []RulePortRange{
{Start: 80, End: 80},
{Start: 443, End: 443},
},
},
},
}
}
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
@@ -1805,7 +1644,7 @@ func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *n
// peerSupportedFirewallFeatures checks if the peer version supports port ranges.
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
if version.IsDevelopmentVersion(peerVer) {
if strings.Contains(peerVer, "dev") {
return supportedFeatures{true, true}
}

View File

@@ -119,7 +119,6 @@ func (a *Account) GetPeerNetworkMapComponents(
peerGroups := a.GetPeerGroups(peerID)
components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups)
components.AccountZones = append(components.AccountZones, a.SynthesizePrivateServiceZones(peerID)...)
for _, nsGroup := range a.NameServerGroups {
if nsGroup.Enabled {

View File

@@ -1,85 +0,0 @@
package types
import (
"context"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func TestPrivateService_NetworkMap_UserPeer_AndProxyPeer(t *testing.T) {
account := privateZoneTestAccount(t)
account.Peers["user-peer"].Meta.WtVersion = "0.50.0"
account.Peers["proxy-peer"].Meta.WtVersion = "0.50.0"
ctx := context.Background()
account.InjectProxyPolicies(ctx)
validated := map[string]struct{}{
"user-peer": {},
"proxy-peer": {},
}
t.Run("user-peer update", func(t *testing.T) {
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
require.NotNil(t, nm)
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
require.True(t, ok)
require.Len(t, zone.Records, 1)
assert.Equal(t, "myapp.eu.proxy.netbird.io.", zone.Records[0].Name)
assert.Equal(t, int(dns.TypeA), zone.Records[0].Type)
assert.Equal(t, "100.64.0.99", zone.Records[0].RData)
assert.Contains(t, netmapPeerIDs(nm.Peers), "proxy-peer")
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.99", FirewallRuleDirectionOUT)
})
t.Run("proxy-peer update", func(t *testing.T) {
nm := account.GetPeerNetworkMapFromComponents(ctx, "proxy-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
require.NotNil(t, nm)
assert.Contains(t, netmapPeerIDs(nm.Peers), "user-peer")
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.10", FirewallRuleDirectionIN)
})
}
func netmapPeerIDs(peers []*nbpeer.Peer) []string {
ids := make([]string, 0, len(peers))
for _, p := range peers {
if p == nil {
continue
}
ids = append(ids, p.ID)
}
return ids
}
func assertPrivateServiceFirewallRules(t *testing.T, rules []*FirewallRule, peerIP string, direction int) {
t.Helper()
wantPorts := map[uint16]bool{80: false, 443: false}
for _, r := range rules {
if r == nil || r.PeerIP != peerIP || r.Direction != direction {
continue
}
if r.Protocol != string(PolicyRuleProtocolTCP) || r.Action != string(PolicyTrafficActionAccept) {
continue
}
switch {
case r.PortRange.Start == r.PortRange.End && r.PortRange.Start != 0:
wantPorts[r.PortRange.Start] = true
case r.Port == "80":
wantPorts[80] = true
case r.Port == "443":
wantPorts[443] = true
}
}
for port, found := range wantPorts {
assert.Truef(t, found, "missing TCP accept rule on port %d for peer %s direction %d", port, peerIP, direction)
}
}

View File

@@ -1,256 +0,0 @@
package types
import (
"context"
"net"
"net/netip"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func privateZoneTestAccount(t *testing.T) *Account {
t.Helper()
return &Account{
Id: "acct-1",
Settings: &Settings{},
Network: &Network{
Identifier: "net-1",
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
},
Peers: map[string]*nbpeer.Peer{
"user-peer": {
ID: "user-peer",
AccountID: "acct-1",
Key: "user-peer-key",
IP: netip.MustParseAddr("100.64.0.10"),
Status: &nbpeer.PeerStatus{Connected: true},
},
"proxy-peer": {
ID: "proxy-peer",
AccountID: "acct-1",
Key: "proxy-peer-key",
IP: netip.MustParseAddr("100.64.0.99"),
Status: &nbpeer.PeerStatus{Connected: true},
ProxyMeta: nbpeer.ProxyMeta{
Embedded: true,
Cluster: "eu.proxy.netbird.io",
},
},
},
Groups: map[string]*Group{
"grp-admins": {
ID: "grp-admins",
Name: "admins",
Peers: []string{"user-peer"},
},
},
Services: []*service.Service{
{
ID: "svc-1",
AccountID: "acct-1",
Name: "myapp",
Domain: "myapp.eu.proxy.netbird.io",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
},
},
}
}
func TestSynthesizePrivateServiceZones_PeerInGroup_GetsRecord(t *testing.T) {
account := privateZoneTestAccount(t)
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "one cluster should produce one zone")
zone := zones[0]
assert.Equal(t, "eu.proxy.netbird.io.", zone.Domain, "zone apex must be the cluster FQDN")
assert.True(t, zone.NonAuthoritative, "synth zone must be match-only so unrelated sibling names fall through to the upstream resolver")
require.Len(t, zone.Records, 1, "one private service yields one A record")
rec := zone.Records[0]
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
assert.Equal(t, int(dns.TypeA), rec.Type, "record type must be A")
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
assert.Equal(t, privateServiceDNSRecordTTL, rec.TTL, "TTL must match the synth-records constant")
assert.Equal(t, nbdns.DefaultClass, rec.Class, "record class must be the package default")
}
func TestSynthesizePrivateServiceZones_PeerNotInGroup_NoRecord(t *testing.T) {
account := privateZoneTestAccount(t)
account.Groups["grp-admins"].Peers = nil
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "peer outside distribution_groups must not see private-service records")
}
func TestSynthesizePrivateServiceZones_NotPrivate_NoRecord(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services[0].Private = false
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "non-private service must not produce DNS records")
}
func TestSynthesizePrivateServiceZones_NoAccessGroups_NoRecord(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services[0].AccessGroups = nil
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "private service without bearer auth must not produce DNS records")
}
func TestSynthesizePrivateServiceZones_NoProxyPeers_NoRecord(t *testing.T) {
account := privateZoneTestAccount(t)
delete(account.Peers, "proxy-peer")
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "no embedded proxy peer in cluster means no record to emit")
}
func TestSynthesizePrivateServiceZones_DisabledService_NoRecord(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services[0].Enabled = false
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "disabled service must not produce DNS records")
}
func TestSynthesizePrivateServiceZones_DisconnectedProxyPeer_NoRecord(t *testing.T) {
account := privateZoneTestAccount(t)
account.Peers["proxy-peer"].Status = &nbpeer.PeerStatus{Connected: false}
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "disconnected proxy peer must not produce a DNS record (would be a black hole)")
}
func TestSynthesizePrivateServiceZones_PartiallyDisconnectedProxyPeers_OnlyConnectedSurface(t *testing.T) {
account := privateZoneTestAccount(t)
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
ID: "proxy-peer-2",
AccountID: "acct-1",
Key: "proxy-peer-2-key",
IP: netip.MustParseAddr("100.64.0.100"),
Status: &nbpeer.PeerStatus{Connected: false},
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1)
require.Len(t, zones[0].Records, 1, "only the connected proxy peer must surface")
assert.Equal(t, "100.64.0.99", zones[0].Records[0].RData)
}
func TestSynthesizePrivateServiceZones_MultipleProxyPeers_RoundRobin(t *testing.T) {
account := privateZoneTestAccount(t)
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
ID: "proxy-peer-2",
AccountID: "acct-1",
Key: "proxy-peer-2-key",
IP: netip.MustParseAddr("100.64.0.100"),
Status: &nbpeer.PeerStatus{Connected: true},
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "still one cluster yields one zone")
require.Len(t, zones[0].Records, 2, "two proxy peers must produce two A records on the same name")
rdata := []string{zones[0].Records[0].RData, zones[0].Records[1].RData}
assert.ElementsMatch(t, []string{"100.64.0.99", "100.64.0.100"}, rdata, "both proxy peer IPs must surface")
}
// findCustomZone returns the CustomZone whose Domain equals the FQDN
// of want, or a zero value when not found. Tests use it to assert
// that the synth zone reaches dnsUpdate.CustomZones end-to-end.
func findCustomZone(zones []nbdns.CustomZone, want string) (nbdns.CustomZone, bool) {
wantFqdn := dns.Fqdn(want)
for _, z := range zones {
if z.Domain == wantFqdn {
return z, true
}
}
return nbdns.CustomZone{}, false
}
// TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone
// covers the components-based builder path. The components builder
// appends SynthesizePrivateServiceZones to AccountZones; the
// CalculateNetworkMapFromComponents step then merges AccountZones
// into dnsUpdate.CustomZones.
func TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone(t *testing.T) {
account := privateZoneTestAccount(t)
ctx := context.Background()
validated := map[string]struct{}{
"user-peer": {},
"proxy-peer": {},
}
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
require.NotNil(t, nm, "network map must be produced for an in-account peer")
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
require.True(t, ok, "shipped CustomZones must include the synth zone for the cluster")
require.Len(t, zone.Records, 1, "exactly one record per private service per connected proxy peer")
rec := zone.Records[0]
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
}
// TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone
// confirms the negative case the user encountered: a peer whose
// groups don't overlap the policy's distribution_groups gets a
// network map with no synth zone (and the wildcard / peer zones still
// flow through). This is the test mirror of the runtime confusion
// where the user looked at a non-distribution-group peer and assumed
// the synth path was broken.
func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Peers["outsider"] = &nbpeer.Peer{
ID: "outsider",
AccountID: "acct-1",
Key: "outsider-key",
IP: netip.MustParseAddr("100.64.0.20"),
Status: &nbpeer.PeerStatus{Connected: true},
}
ctx := context.Background()
validated := map[string]struct{}{
"user-peer": {},
"proxy-peer": {},
"outsider": {},
}
nm := account.GetPeerNetworkMapFromComponents(ctx, "outsider", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
require.NotNil(t, nm)
_, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "anotherapp",
Domain: "anotherapp.eu.proxy.netbird.io",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "two services on the same cluster must collapse into one zone")
require.Len(t, zones[0].Records, 2, "two services yield two A records")
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
}

View File

@@ -3,7 +3,6 @@ package types
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
@@ -12,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -84,9 +82,9 @@ func setupTestAccount() *Account {
},
Groups: map[string]*Group{
"groupAll": {
ID: "groupAll",
Name: "All",
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
ID: "groupAll",
Name: "All",
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
Issued: GroupIssuedAPI,
},
"group1": {
@@ -646,7 +644,41 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
expectedPorts: []string{"20-25", "10-100", "22022"},
},
{
name: "development version supports all features",
name: "dev suffix version supports all features",
peer: &nbpeer.Peer{
ID: "peer1",
SSHEnabled: true,
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.50.0-dev",
Flags: nbpeer.Flags{ServerSSHAllowed: true},
},
},
rule: &PolicyRule{
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
},
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
expectedPorts: []string{"22", "22022"},
},
{
name: "dev suffix version supports all features",
peer: &nbpeer.Peer{
ID: "peer1",
SSHEnabled: true,
Meta: nbpeer.PeerSystemMeta{
WtVersion: "dev",
Flags: nbpeer.Flags{ServerSSHAllowed: true},
},
},
rule: &PolicyRule{
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
},
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
expectedPorts: []string{"22", "22022"},
},
{
name: "development suffix version supports all features",
peer: &nbpeer.Peer{
ID: "peer1",
SSHEnabled: true,
@@ -1551,203 +1583,3 @@ func Test_filterPeerAppliedZones(t *testing.T) {
})
}
}
func TestInjectPrivateServicePolicies_ProxyPeerGetsInboundRule(t *testing.T) {
ctx := context.Background()
userPeerIP := netip.MustParseAddr("100.64.0.10")
proxyPeerIP := netip.MustParseAddr("100.64.0.99")
account := &Account{
Id: "acct-1",
Network: &Network{
Identifier: "net-1",
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
},
Peers: map[string]*nbpeer.Peer{
"user-peer": {
ID: "user-peer",
AccountID: "acct-1",
Key: "user-peer-key",
IP: userPeerIP,
},
"proxy-peer": {
ID: "proxy-peer",
AccountID: "acct-1",
Key: "proxy-peer-key",
IP: proxyPeerIP,
ProxyMeta: nbpeer.ProxyMeta{
Embedded: true,
Cluster: "eu.proxy.netbird.io",
},
},
},
Groups: map[string]*Group{
"grp-admins": {
ID: "grp-admins",
Name: "admins",
Peers: []string{"user-peer"},
},
},
Services: []*service.Service{
{
ID: "svc-1",
AccountID: "acct-1",
Name: "myapp",
Domain: "myapp.eu.proxy.netbird.io",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
Targets: []*service.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: service.TargetTypeCluster,
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Enabled: true,
},
},
},
},
}
account.InjectProxyPolicies(ctx)
var found *Policy
for _, p := range account.Policies {
if p != nil && p.ID == "private-access-svc-1-proxy-peer" {
found = p
break
}
}
require.NotNil(t, found, "expected synthesised private-access policy in account.Policies")
require.Len(t, found.Rules, 1, "policy should have exactly one rule")
rule := found.Rules[0]
assert.Equal(t, []string{"grp-admins"}, rule.Sources, "sources should be group IDs verbatim")
assert.Equal(t, "proxy-peer", rule.DestinationResource.ID, "destination resource should be the proxy peer ID")
assert.Equal(t, ResourceTypePeer, rule.DestinationResource.Type, "destination resource type should be peer")
validatedPeersMap := map[string]struct{}{
"user-peer": {},
"proxy-peer": {},
}
proxyPeer := account.Peers["proxy-peer"]
aclPeers, firewallRules, _, _ := account.GetPeerConnectionResources(ctx, proxyPeer, validatedPeersMap, nil)
var sawUserAsAclPeer bool
for _, p := range aclPeers {
if p.ID == "user-peer" {
sawUserAsAclPeer = true
break
}
}
assert.True(t, sawUserAsAclPeer, "proxy peer should see the user peer as an ACL peer")
var inboundRules []*FirewallRule
for _, r := range firewallRules {
if r.Direction == FirewallRuleDirectionIN && r.PeerIP == userPeerIP.String() {
inboundRules = append(inboundRules, r)
}
}
assert.NotEmpty(t, inboundRules, "proxy peer should have inbound firewall rules from the user peer")
}
func TestInjectPrivateServicePolicies_NotPrivate_NoPolicy(t *testing.T) {
ctx := context.Background()
account := privateServiceTestAccount(t)
account.Services[0].Private = false
account.InjectProxyPolicies(ctx)
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "non-private service must not synthesise an access policy")
}
func TestInjectPrivateServicePolicies_EmptyAccessGroups_NoPolicy(t *testing.T) {
ctx := context.Background()
account := privateServiceTestAccount(t)
account.Services[0].AccessGroups = nil
account.InjectProxyPolicies(ctx)
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "private service with no access groups must not synthesise a policy")
}
func TestInjectPrivateServicePolicies_NoProxyPeers_NoPolicy(t *testing.T) {
ctx := context.Background()
account := privateServiceTestAccount(t)
delete(account.Peers, "proxy-peer")
account.InjectProxyPolicies(ctx)
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "policy must not synthesise when the cluster has no proxy peers")
}
func privateServiceTestAccount(t *testing.T) *Account {
t.Helper()
return &Account{
Id: "acct-1",
Network: &Network{
Identifier: "net-1",
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
},
Peers: map[string]*nbpeer.Peer{
"user-peer": {
ID: "user-peer",
AccountID: "acct-1",
Key: "user-peer-key",
IP: netip.MustParseAddr("100.64.0.10"),
},
"proxy-peer": {
ID: "proxy-peer",
AccountID: "acct-1",
Key: "proxy-peer-key",
IP: netip.MustParseAddr("100.64.0.99"),
ProxyMeta: nbpeer.ProxyMeta{
Embedded: true,
Cluster: "eu.proxy.netbird.io",
},
},
},
Groups: map[string]*Group{
"grp-admins": {
ID: "grp-admins",
Name: "admins",
Peers: []string{"user-peer"},
},
},
Services: []*service.Service{
{
ID: "svc-1",
AccountID: "acct-1",
Name: "myapp",
Domain: "myapp.eu.proxy.netbird.io",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
Targets: []*service.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: service.TargetTypeCluster,
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Enabled: true,
},
},
},
},
}
}
func hasPrivateAccessPolicy(account *Account, serviceID string) bool {
prefix := "private-access-" + serviceID + "-"
for _, p := range account.Policies {
if p != nil && len(p.ID) > len(prefix) && p.ID[:len(prefix)] == prefix {
return true
}
}
return false
}

View File

@@ -10,7 +10,6 @@ import (
type Manager interface {
GetUser(ctx context.Context, userID string) (*types.User, error)
GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error)
}
type managerImpl struct {
@@ -30,31 +29,6 @@ func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User,
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}
// GetUserWithGroups returns the user and the *types.Group records for the user's AutoGroups, in the same order as
// AutoGroups. Group ids that don't resolve to a stored group are skipped from the returned slice (the parallel id list is
// derivable from the returned User). Wraps two store calls today; can be optimised to a single JOIN later if needed.
// Any store error returns (nil, nil, err) so callers never receive a valid user alongside a non-nil error.
func (m *managerImpl) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, nil, err
}
if len(user.AutoGroups) == 0 {
return user, nil, nil
}
groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups)
if err != nil {
return nil, nil, err
}
groups := make([]*types.Group, 0, len(user.AutoGroups))
for _, id := range user.AutoGroups {
if g, ok := groupsMap[id]; ok && g != nil {
groups = append(groups, g)
}
}
return user, groups, nil
}
func NewManagerMock() Manager {
return &managerMock{}
}
@@ -73,11 +47,3 @@ func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User,
return nil, errors.New("user not found")
}
}
func (m *managerMock) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.GetUser(ctx, userID)
if err != nil {
return nil, nil, err
}
return user, nil, nil
}

View File

@@ -45,14 +45,10 @@ func ResolveProto(forwardedProto string, conn *tls.ConnectionState) string {
}
}
// ValidateSessionJWT validates a session JWT and returns the user ID, the
// user's email (when carried), the authentication method, any embedded
// group memberships, and the parallel group display names. email,
// groups, and groupNames may be empty for tokens minted before those
// claims were introduced. groupNames pairs positionally with groups.
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, email, method string, groups, groupNames []string, err error) {
// ValidateSessionJWT validates a session JWT and returns the user ID and method.
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) {
if publicKey == nil {
return "", "", "", nil, nil, fmt.Errorf("no public key configured for domain")
return "", "", fmt.Errorf("no public key configured for domain")
}
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
@@ -62,46 +58,20 @@ func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey)
return publicKey, nil
}, jwt.WithAudience(domain), jwt.WithIssuer(SessionJWTIssuer))
if err != nil {
return "", "", "", nil, nil, fmt.Errorf("parse token: %w", err)
return "", "", fmt.Errorf("parse token: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return "", "", "", nil, nil, fmt.Errorf("invalid token claims")
return "", "", fmt.Errorf("invalid token claims")
}
sub, _ := claims.GetSubject()
if sub == "" {
return "", "", "", nil, nil, fmt.Errorf("missing subject claim")
return "", "", fmt.Errorf("missing subject claim")
}
methodClaim, _ := claims["method"].(string)
emailClaim, _ := claims["email"].(string)
groups = extractGroupsClaim(claims["groups"])
groupNames = extractGroupsClaim(claims["group_names"])
return sub, emailClaim, methodClaim, groups, groupNames, nil
}
// extractGroupsClaim decodes the "groups" claim into a string slice. The JWT
// library decodes JSON arrays as []interface{}, so we coerce element-wise
// and skip non-string entries silently.
func extractGroupsClaim(claim interface{}) []string {
raw, ok := claim.([]interface{})
if !ok {
return nil
}
if len(raw) == 0 {
return nil
}
groups := make([]string, 0, len(raw))
for _, v := range raw {
if s, ok := v.(string); ok && s != "" {
groups = append(groups, s)
}
}
if len(groups) == 0 {
return nil
}
return groups
return sub, methodClaim, nil
}

View File

@@ -63,7 +63,6 @@ var (
preSharedKey string
supportsCustomPorts bool
requireSubdomain bool
private bool
geoDataDir string
crowdsecAPIURL string
crowdsecAPIKey string
@@ -106,8 +105,6 @@ func init() {
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain")
rootCmd.Flags().BoolVar(&private, "private", envBoolOrDefault("NB_PROXY_PRIVATE", false), "Enable private services accessible with NetBird-Only authentication mode.")
_ = rootCmd.Flags().MarkHidden("private")
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
@@ -164,8 +161,7 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --trusted-proxies: %w", err)
}
srv := proxy.New(proxy.Config{
ListenAddr: addr,
srv := proxy.Server{
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
@@ -182,7 +178,7 @@ func runServer(cmd *cobra.Command, args []string) error {
ACMEChallengeType: acmeChallengeType,
DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr,
HealthAddr: healthAddr,
HealthAddress: healthAddr,
ForwardedProto: forwardedProto,
TrustedProxies: parsedTrustedProxies,
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
@@ -192,13 +188,12 @@ func runServer(cmd *cobra.Command, args []string) error {
PreSharedKey: preSharedKey,
SupportsCustomPorts: supportsCustomPorts,
RequireSubdomain: requireSubdomain,
Private: private,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
GeoDataDir: geoDataDir,
CrowdSecAPIURL: crowdsecAPIURL,
CrowdSecAPIKey: crowdsecAPIKey,
})
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()

View File

@@ -4,7 +4,6 @@ import (
"context"
"io"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -60,7 +59,7 @@ func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "initial sync should be marked done when flag is set")
}
@@ -80,7 +79,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "initial sync should not be marked done without flag")
}
@@ -98,7 +97,7 @@ func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync done flag should be set even without health checker")
}

View File

@@ -1,547 +0,0 @@
package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
stdlog "log"
"net"
"net/http"
"net/netip"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/debug"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// httpInboundReadHeaderTimeout matches the host-listener read header timeout
// so per-account http.Servers don't leak idle connections.
const httpInboundReadHeaderTimeout = 30 * time.Second
// httpInboundIdleTimeout caps idle keep-alive on per-account inbound HTTP
// servers; matches the host listener.
const httpInboundIdleTimeout = 90 * time.Second
// inboundShutdownTimeout caps how long a per-account http.Server gets to
// drain in-flight requests during teardown.
const inboundShutdownTimeout = 5 * time.Second
// privateInboundPortHTTPS is the WG-side TLS port. Each account's
// embedded netstack binds independently, so a fixed port is fine.
const privateInboundPortHTTPS = 443
// privateInboundPortHTTP is the WG-side plain-HTTP port.
const privateInboundPortHTTP = 80
// inboundManager wires per-account inbound listeners into the proxy
// pipeline when --private-inbound is enabled. When disabled the manager
// is nil and every method on *Server that touches it short-circuits.
type inboundManager struct {
logger *log.Logger
handler http.Handler
tlsConfig *tls.Config
// muxLock guards entries and pendingRoutes.
muxLock sync.Mutex
entries map[types.AccountID]*inboundEntry
pendingRoutes map[types.AccountID][]pendingInboundRoute
}
// inboundEntry owns the listeners, router and HTTP servers for a single
// account's embedded netstack.
type inboundEntry struct {
router *nbtcp.Router
tlsListener net.Listener
plainListener net.Listener
httpsServer *http.Server
httpServer *http.Server
cancel context.CancelFunc
wg sync.WaitGroup
}
// pendingInboundRoute holds a route that arrived before the account's
// listener finished starting.
type pendingInboundRoute struct {
host nbtcp.SNIHost
route nbtcp.Route
}
// newInboundManager constructs a manager bound to the proxy's HTTP
// handler chain and TLS config.
func newInboundManager(logger *log.Logger, handler http.Handler, tlsConfig *tls.Config) *inboundManager {
return &inboundManager{
logger: logger,
handler: handler,
tlsConfig: tlsConfig,
entries: make(map[types.AccountID]*inboundEntry),
pendingRoutes: make(map[types.AccountID][]pendingInboundRoute),
}
}
// onClientReady is registered with NetBird.SetClientLifecycle so the
// listener pair comes up exactly when the embedded client reports ready.
// The returned value is opaque to the roundtrip package; it is handed
// back verbatim to onClientStop on teardown.
func (m *inboundManager) onClientReady(ctx context.Context, accountID types.AccountID, client *embed.Client) any {
if m == nil {
return nil
}
entry, err := m.bringUp(ctx, accountID, client)
if err != nil {
m.logger.WithField("account_id", accountID).WithError(err).Warn("failed to start per-account inbound listener; continuing without inbound")
return nil
}
m.flushPending(accountID, entry)
m.logger.WithFields(log.Fields{
"account_id": accountID,
"https": entry.tlsListener.Addr().String(),
"http": entry.plainListener.Addr().String(),
}).Info("per-account inbound listeners up")
return entry
}
// onClientStop tears down a per-account listener bundle. State is the
// opaque value previously returned by onClientReady.
func (m *inboundManager) onClientStop(accountID types.AccountID, state any) {
if m == nil {
return
}
entry, ok := state.(*inboundEntry)
if !ok || entry == nil {
return
}
m.tearDown(accountID, entry)
}
// bringUp opens both listeners on the account's netstack, builds the
// router, and starts the parallel HTTP servers.
func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID, client *embed.Client) (*inboundEntry, error) {
tlsListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTPS))
if err != nil {
return nil, fmt.Errorf("listen tls on netstack: %w", err)
}
plainListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTP))
if err != nil {
_ = tlsListener.Close()
return nil, fmt.Errorf("listen plain on netstack: %w", err)
}
router := nbtcp.NewRouter(m.logger, accountDialResolver(accountID, client), tlsListener.Addr(), nbtcp.WithPlainHTTP(plainListener.Addr()))
scopedHandler := withTunnelLookup(m.handler, accountTunnelLookup(client))
// markOverlayOrigin stamps every connection accepted by an inbound
// listener with a context value middlewares can read to skip
// geo/CrowdSec checks (the source address is always inside the
// NetBird CGNAT range and won't match either dataset).
markOverlayOrigin := func(ctx context.Context, _ net.Conn) context.Context {
return types.WithOverlayOrigin(ctx)
}
httpsServer := &http.Server{
Handler: scopedHandler,
TLSConfig: m.tlsConfig,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "https", accountID),
ConnContext: markOverlayOrigin,
}
httpServer := &http.Server{
Handler: scopedHandler,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "http", accountID),
ConnContext: markOverlayOrigin,
}
runCtx, cancel := context.WithCancel(ctx)
entry := &inboundEntry{
router: router,
tlsListener: tlsListener,
plainListener: plainListener,
httpsServer: httpsServer,
httpServer: httpServer,
cancel: cancel,
}
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
if err := router.Serve(runCtx, tlsListener); err != nil {
m.logger.WithField("account_id", accountID).Debugf("per-account router stopped: %v", err)
}
}()
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
if err := httpsServer.ServeTLS(router.HTTPListener(), "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
m.logger.WithField("account_id", accountID).Debugf("per-account https server stopped: %v", err)
}
}()
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
if err := httpServer.Serve(router.HTTPListenerPlain()); err != nil && !errors.Is(err, http.ErrServerClosed) {
m.logger.WithField("account_id", accountID).Debugf("per-account http server stopped: %v", err)
}
}()
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
feedRouterFromListener(runCtx, plainListener, router, m.logger, accountID)
}()
m.muxLock.Lock()
m.entries[accountID] = entry
m.muxLock.Unlock()
return entry, nil
}
// tearDown shuts every goroutine down and closes the netstack listeners.
func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry) {
m.muxLock.Lock()
if m.entries[accountID] == entry {
delete(m.entries, accountID)
delete(m.pendingRoutes, accountID)
}
m.muxLock.Unlock()
entry.cancel()
shutdownCtx, cancel := context.WithTimeout(context.Background(), inboundShutdownTimeout)
defer cancel()
if err := entry.httpsServer.Shutdown(shutdownCtx); err != nil {
m.logger.Debugf("per-account https shutdown: %v", err)
}
if err := entry.httpServer.Shutdown(shutdownCtx); err != nil {
m.logger.Debugf("per-account http shutdown: %v", err)
}
if err := entry.tlsListener.Close(); err != nil {
m.logger.Debugf("close per-account tls listener: %v", err)
}
if err := entry.plainListener.Close(); err != nil {
m.logger.Debugf("close per-account plain listener: %v", err)
}
entry.wg.Wait()
}
// AddRoute records an SNI/host route on the account's per-account router.
// Routes registered before the listener is up are queued and replayed
// once startup completes.
func (m *inboundManager) AddRoute(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
if m == nil {
return
}
m.muxLock.Lock()
entry, ok := m.entries[accountID]
if !ok {
m.queuePendingLocked(accountID, host, route)
m.muxLock.Unlock()
return
}
router := entry.router
m.muxLock.Unlock()
router.AddRoute(host, route)
}
// RemoveRoute drops a previously registered route. Safe to call when the
// listener is not yet up; queued copies are pruned in that case.
func (m *inboundManager) RemoveRoute(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
if m == nil {
return
}
m.muxLock.Lock()
m.dropPendingLocked(accountID, host, svcID)
entry, ok := m.entries[accountID]
if !ok {
m.muxLock.Unlock()
return
}
router := entry.router
m.muxLock.Unlock()
router.RemoveRoute(host, svcID)
}
// queuePendingLocked stores or upserts a pending route. Caller holds muxLock.
func (m *inboundManager) queuePendingLocked(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
queued := m.pendingRoutes[accountID]
for i, pr := range queued {
if pr.host == host && pr.route.ServiceID == route.ServiceID {
queued[i] = pendingInboundRoute{host: host, route: route}
m.pendingRoutes[accountID] = queued
return
}
}
m.pendingRoutes[accountID] = append(queued, pendingInboundRoute{host: host, route: route})
}
// dropPendingLocked removes any queued route matching host/svcID.
// Caller holds muxLock.
func (m *inboundManager) dropPendingLocked(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
queued, ok := m.pendingRoutes[accountID]
if !ok {
return
}
filtered := queued[:0]
for _, pr := range queued {
if pr.host == host && pr.route.ServiceID == svcID {
continue
}
filtered = append(filtered, pr)
}
if len(filtered) == 0 {
delete(m.pendingRoutes, accountID)
return
}
m.pendingRoutes[accountID] = filtered
}
// flushPending applies all queued routes to a freshly-up router.
func (m *inboundManager) flushPending(accountID types.AccountID, entry *inboundEntry) {
m.muxLock.Lock()
queued := m.pendingRoutes[accountID]
delete(m.pendingRoutes, accountID)
m.muxLock.Unlock()
for _, pr := range queued {
entry.router.AddRoute(pr.host, pr.route)
}
}
// HasInbound reports whether the manager has a live listener for the account.
// Used by tests.
func (m *inboundManager) HasInbound(accountID types.AccountID) bool {
if m == nil {
return false
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
_, ok := m.entries[accountID]
return ok
}
// PendingRouteCount reports the number of queued routes for the account.
// Used by tests.
func (m *inboundManager) PendingRouteCount(accountID types.AccountID) int {
if m == nil {
return 0
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
return len(m.pendingRoutes[accountID])
}
// InboundListenerInfo describes the bound addresses of a single
// per-account inbound listener. Both addresses live on the embedded
// netstack of the account's WireGuard client and share the same tunnel IP.
type InboundListenerInfo struct {
TunnelIP string
HTTPSPort uint16
HTTPPort uint16
}
// ListenerInfo returns the inbound listener addresses for the given
// account, or ok=false when the account has no live listener. Used by
// the status-update RPC and the debug HTTP handler to surface inbound
// reachability to operators.
func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListenerInfo, bool) {
if m == nil {
return InboundListenerInfo{}, false
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
entry, ok := m.entries[accountID]
if !ok || entry == nil {
return InboundListenerInfo{}, false
}
return listenerInfoFromEntry(entry), true
}
// Snapshot returns the inbound listener state for every account that has
// a live listener at call time. Empty when --private-inbound is off or
// no accounts have come up yet.
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
if m == nil {
return nil
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
if len(m.entries) == 0 {
return nil
}
out := make(map[types.AccountID]InboundListenerInfo, len(m.entries))
for id, entry := range m.entries {
if entry == nil {
continue
}
out[id] = listenerInfoFromEntry(entry)
}
return out
}
// listenerInfoFromEntry extracts the tunnel IP and ports from a live
// per-account entry. Both listeners are bound on the same netstack so
// their host components match; we still pull the TLS host as the
// authoritative source.
func listenerInfoFromEntry(entry *inboundEntry) InboundListenerInfo {
info := InboundListenerInfo{HTTPSPort: privateInboundPortHTTPS, HTTPPort: privateInboundPortHTTP}
if entry.tlsListener != nil {
host, port := splitHostPort(entry.tlsListener.Addr())
info.TunnelIP = host
if port != 0 {
info.HTTPSPort = port
}
}
if entry.plainListener != nil {
host, port := splitHostPort(entry.plainListener.Addr())
if info.TunnelIP == "" {
info.TunnelIP = host
}
if port != 0 {
info.HTTPPort = port
}
}
return info
}
// splitHostPort extracts host and port from a net.Addr, returning the
// zero values when the address is missing or malformed.
func splitHostPort(addr net.Addr) (string, uint16) {
if addr == nil {
return "", 0
}
host, portStr, err := net.SplitHostPort(addr.String())
if err != nil {
return "", 0
}
if portStr == "" {
return host, 0
}
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return host, 0
}
return host, uint16(port)
}
// feedRouterFromListener accepts on the plain-HTTP netstack listener and
// hands every connection to the account's router. The router peeks the
// first byte and dispatches to the plain-HTTP channel for non-TLS
// streams or the TLS channel for ClientHellos that arrive on :80.
func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.Router, logger *log.Logger, accountID types.AccountID) {
go func() {
<-ctx.Done()
_ = ln.Close()
}()
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
continue
}
router.HandleConn(ctx, conn)
}
}
// accountDialResolver returns a DialResolver bound to a single account's
// embedded client. The router only ever serves traffic for that account
// so the supplied accountID is ignored at dial time.
func accountDialResolver(_ types.AccountID, client *embed.Client) nbtcp.DialResolver {
return func(_ types.AccountID) (types.DialContextFunc, error) {
return client.DialContext, nil
}
}
// accountTunnelLookup returns a TunnelLookupFunc backed by the embedded
// client's peerstore for a single account. Phase 3 uses the result to
// short-circuit ValidateTunnelPeer when the source IP is not in the
// account's roster and to seed the cached identity for known peers.
func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
if client == nil {
return nil
}
return func(ip netip.Addr) (auth.PeerIdentity, bool) {
pubKey, fqdn, ok := client.IdentityForIP(ip)
if !ok {
return auth.PeerIdentity{}, false
}
return auth.PeerIdentity{
PubKey: pubKey,
TunnelIP: ip,
FQDN: fqdn,
}, true
}
}
// withTunnelLookup returns an http.Handler that attaches the per-account
// peerstore lookup to every request's context before delegating to next.
// Calling on the host-level listener is a no-op because that path never
// installs this wrapper, so the existing behaviour stays byte-for-byte
// identical when --private-inbound is off or the request didn't arrive
// on a per-account listener.
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
if lookup == nil {
return next
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := auth.WithTunnelLookup(r.Context(), lookup)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// inboundDebugAdapter adapts *inboundManager to the debug.InboundProvider
// interface so the debug HTTP handler can render per-account inbound
// listener state without importing the proxy package.
type inboundDebugAdapter struct {
mgr *inboundManager
}
// InboundListeners returns a snapshot of the live per-account inbound
// listeners formatted for the debug surface.
func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.InboundListenerInfo {
if a.mgr == nil {
return nil
}
snap := a.mgr.Snapshot()
if len(snap) == 0 {
return nil
}
out := make(map[types.AccountID]debug.InboundListenerInfo, len(snap))
for id, info := range snap {
out[id] = debug.InboundListenerInfo{
TunnelIP: info.TunnelIP,
HTTPSPort: info.HTTPSPort,
HTTPPort: info.HTTPPort,
}
}
return out
}
// newInboundErrorLog routes a per-account http.Server's stdlib error
// stream through logrus at warn level.
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
return stdlog.New(logger.WithFields(log.Fields{
"inbound-http": scheme,
"account_id": accountID,
}).WriterLevel(log.WarnLevel), "", 0)
}

View File

@@ -1,502 +0,0 @@
package proxy
import (
"bufio"
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// bufioReader wraps the connection in a buffered reader so http.ReadResponse
// can parse the response line + headers off the wire.
func bufioReader(conn net.Conn) *bufio.Reader {
return bufio.NewReader(conn)
}
// quietLogger returns a logger that emits nothing — keeps test output tidy.
func quietLogger() *log.Logger {
logger := log.New()
logger.SetLevel(log.PanicLevel)
return logger
}
func TestInboundManager_RouteScopedToAccount(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountA := types.AccountID("acct-a")
accountB := types.AccountID("acct-b")
mgr.AddRoute(accountA, "shared.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountA, ServiceID: "svc-a", Domain: "shared.example"})
mgr.AddRoute(accountB, "other.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: "other.example"})
require.Equal(t, 1, mgr.PendingRouteCount(accountA), "account A should have one queued route")
require.Equal(t, 1, mgr.PendingRouteCount(accountB), "account B should have one queued route")
mgr.RemoveRoute(accountA, "shared.example", "svc-a")
mgr.RemoveRoute(accountB, "other.example", "svc-b")
assert.Equal(t, 0, mgr.PendingRouteCount(accountA), "queue should drain on remove")
assert.Equal(t, 0, mgr.PendingRouteCount(accountB), "queue should drain on remove")
}
func TestInboundManager_PendingThenFlush(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-1")
host := nbtcp.SNIHost("example.test")
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-1", Domain: "example.test"}
mgr.AddRoute(accountID, host, route)
require.Equal(t, 1, mgr.PendingRouteCount(accountID), "pending count before listener is up")
// Simulate listener up by registering a fake entry, then flushing.
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
entry := &inboundEntry{router: router}
mgr.muxLock.Lock()
mgr.entries[accountID] = entry
mgr.muxLock.Unlock()
mgr.flushPending(accountID, entry)
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "queue should be empty after flush")
}
// fakeAddr is a stub net.Addr for tests that don't actually bind sockets.
type fakeAddr struct {
addr string
}
func (a *fakeAddr) Network() string { return "tcp" }
func (a *fakeAddr) String() string { return a.addr }
// fakeMgmtClient implements roundtrip.managementClient for tests.
type fakeMgmtClient struct{}
func (fakeMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// TestServer_PrivateInbound_NotEnabled_NoManager confirms that with
// --private off the inbound manager is nil and the standalone proxy
// keeps its zero-overhead default path.
func TestServer_PrivateInbound_NotEnabled_NoManager(t *testing.T) {
s := &Server{Logger: quietLogger(), Private: false}
s.initPrivateInbound(http.NotFoundHandler(), nil)
assert.Nil(t, s.inbound, "manager should remain nil when --private is off")
}
// TestServer_PrivateInbound_Enabled_WiresLifecycle confirms that
// --private alone wires the manager into the NetBird transport, so
// AddPeer / RemovePeer drive the lifecycle.
func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) {
s := &Server{Logger: quietLogger(), Private: true}
// Construct a NetBird transport. We can't actually start the embedded
// client here (that needs a real management server), but we can
// confirm that the lifecycle callbacks are registered.
s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{
MgmtAddr: "http://invalid.test",
}, quietLogger(), nil, fakeMgmtClient{})
s.initPrivateInbound(http.NotFoundHandler(), &tls.Config{}) //nolint:gosec
require.NotNil(t, s.inbound, "manager should be set when --private is on")
assert.NotNil(t, s.inbound.handler, "handler should be set on manager")
assert.NotNil(t, s.inbound.tlsConfig, "tls config should be set on manager")
}
// TestInboundManager_AddRouteAfterReady_RegistersDirectly verifies that
// when the listener is already up, AddRoute writes straight to the
// router without queueing.
func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-1")
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
mgr.muxLock.Lock()
mgr.entries[accountID] = &inboundEntry{router: router}
mgr.muxLock.Unlock()
host := nbtcp.SNIHost("ready.example")
mgr.AddRoute(accountID, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-ready", Domain: string(host)})
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "no pending entries when listener is up")
}
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
// bit reported upstream tracks --private exclusively. The previous
// --private-inbound flag has been folded into --private.
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
tests := []struct {
name string
private bool
expected bool
}{
{"off", false, false},
{"on", true, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Server{Private: tt.private}
assert.Equal(t, tt.expected, s.Private, "private capability bit should match --private")
})
}
}
// TestInboundManager_RouteScopedToAccountB_DoesNotMatchA verifies that a
// service registered for account B is invisible to a router serving
// account A. We exercise the path through real per-account routers.
func TestInboundManager_RouteScopedToAccountB_DoesNotMatchA(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountA := types.AccountID("acct-a")
accountB := types.AccountID("acct-b")
routerA := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
routerB := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
mgr.muxLock.Lock()
mgr.entries[accountA] = &inboundEntry{router: routerA}
mgr.entries[accountB] = &inboundEntry{router: routerB}
mgr.muxLock.Unlock()
host := nbtcp.SNIHost("shared.example")
mgr.AddRoute(accountB, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: string(host)})
// Account A's router should have no routes; account B's should have one.
// We check via IsEmpty — true means no routes and no fallback.
assert.True(t, routerA.IsEmpty(), "account A router must not see account B's mappings")
assert.False(t, routerB.IsEmpty(), "account B router should hold its own mapping")
}
// TestInboundEntry_ShutdownIdempotent ensures that tearDown can run twice
// without panicking — callers may invoke it from RemovePeer + StopAll.
func TestInboundEntry_ShutdownIdempotent(t *testing.T) {
t.Skip("teardown requires real netstack listeners; covered by integration tests")
}
// TestRouter_PlainHTTP_ForwardedProtoIsHTTP exercises the full per-account
// router pipeline against a loopback listener (proxy of a netstack
// listener for test purposes): a plain HTTP request lands on the plain
// http.Server and the inner handler observes a nil r.TLS, which is what
// auth.ResolveProto translates to "http" in the real pipeline.
func TestRouter_PlainHTTP_ForwardedProtoIsHTTP(t *testing.T) {
logger := quietLogger()
var captured atomic.Value
captured.Store("")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil {
captured.Store("http")
} else {
captured.Store("https")
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
hostListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "loopback listener bind must succeed")
defer hostListener.Close()
router := nbtcp.NewRouter(logger, nil, hostListener.Addr(), nbtcp.WithPlainHTTP(hostListener.Addr()))
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
defer func() { _ = httpServer.Close() }()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
go func() { _ = router.Serve(ctx, hostListener) }()
conn, err := net.DialTimeout("tcp", hostListener.Addr().String(), 2*time.Second)
require.NoError(t, err, "plain HTTP dial must succeed")
defer conn.Close()
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"))
require.NoError(t, err, "write must succeed")
resp, err := http.ReadResponse(bufioReader(conn), nil)
require.NoError(t, err, "must read response")
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "http", captured.Load(), "ForwardedProto must be http on plain path")
}
// TestWithTunnelLookup_AttachesLookupToContext verifies that requests
// flowing through the per-account handler wrapper carry the peerstore
// lookup function. Phase 3's local-first deny path depends on this.
func TestWithTunnelLookup_AttachesLookupToContext(t *testing.T) {
expected := auth.PeerIdentity{TunnelIP: netip.MustParseAddr("100.64.0.10"), FQDN: "peer.netbird"}
lookup := auth.TunnelLookupFunc(func(_ netip.Addr) (auth.PeerIdentity, bool) {
return expected, true
})
var observed auth.TunnelLookupFunc
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
observed = auth.TunnelLookupFromContext(r.Context())
})
handler := withTunnelLookup(inner, lookup)
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
handler.ServeHTTP(httptest.NewRecorder(), r)
require.NotNil(t, observed, "wrapper must inject the lookup into the request context")
got, ok := observed(netip.MustParseAddr("100.64.0.10"))
assert.True(t, ok, "lookup must round-trip through context")
assert.Equal(t, expected.FQDN, got.FQDN, "lookup must return the same identity it was constructed with")
}
// TestWithTunnelLookup_NilLookupIsNoop confirms the wrapper is a pure
// pass-through when no lookup is provided. Required for the host-level
// listener path to keep its byte-for-byte previous behaviour.
func TestWithTunnelLookup_NilLookupIsNoop(t *testing.T) {
var called bool
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
called = true
assert.Nil(t, auth.TunnelLookupFromContext(r.Context()), "host-level path must not see a lookup function")
})
handler := withTunnelLookup(inner, nil)
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
handler.ServeHTTP(httptest.NewRecorder(), r)
assert.True(t, called, "wrapper without lookup must still invoke next")
}
// fakeListener satisfies net.Listener for snapshot tests without binding
// a real socket on the netstack.
type fakeListener struct {
addr net.Addr
}
func (f *fakeListener) Accept() (net.Conn, error) { return nil, net.ErrClosed }
func (f *fakeListener) Close() error { return nil }
func (f *fakeListener) Addr() net.Addr { return f.addr }
// TestInboundManager_ListenerInfo confirms ListenerInfo and Snapshot
// surface the bound tunnel-IP and ports for live entries.
func TestInboundManager_ListenerInfo(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-info")
tlsAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTPS}
plainAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTP}
mgr.muxLock.Lock()
mgr.entries[accountID] = &inboundEntry{
tlsListener: &fakeListener{addr: tlsAddr},
plainListener: &fakeListener{addr: plainAddr},
}
mgr.muxLock.Unlock()
info, ok := mgr.ListenerInfo(accountID)
require.True(t, ok, "ListenerInfo must report ok for live entry")
assert.Equal(t, "100.64.0.5", info.TunnelIP, "tunnel IP must come from listener address")
assert.Equal(t, uint16(privateInboundPortHTTPS), info.HTTPSPort, "TLS port must match bound port")
assert.Equal(t, uint16(privateInboundPortHTTP), info.HTTPPort, "HTTP port must match bound port")
snap := mgr.Snapshot()
require.Len(t, snap, 1, "snapshot must contain exactly one entry")
assert.Equal(t, info, snap[accountID], "snapshot entry must equal direct lookup")
_, ok = mgr.ListenerInfo(types.AccountID("missing"))
assert.False(t, ok, "ListenerInfo must report ok=false for unknown accounts")
}
// TestInboundManager_NilManagerSafe ensures the observability accessors
// are safe to call when --private-inbound is off (nil manager).
func TestInboundManager_NilManagerSafe(t *testing.T) {
var mgr *inboundManager
_, ok := mgr.ListenerInfo("anything")
assert.False(t, ok, "nil manager must return ok=false")
assert.Nil(t, mgr.Snapshot(), "nil manager must return nil snapshot")
}
// TestInboundManager_ConcurrentAddRemove pounds AddRoute / RemoveRoute
// from multiple goroutines to expose any locking gaps.
func TestInboundManager_ConcurrentAddRemove(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-1")
const workers = 32
const iterations = 50
var wg sync.WaitGroup
wg.Add(workers)
for i := 0; i < workers; i++ {
go func(idx int) {
defer wg.Done()
host := nbtcp.SNIHost("example.test")
svc := types.ServiceID("svc")
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svc, Domain: "example.test"}
for j := 0; j < iterations; j++ {
mgr.AddRoute(accountID, host, route)
mgr.RemoveRoute(accountID, host, svc)
}
}(i)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(10 * time.Second):
t.Fatal("concurrent add/remove timed out")
}
}
// TestFeedRouterFromListener_DeliversConnectionToHandler validates the
// per-account inbound chain end-to-end with a loopback listener
// substituted for the embedded netstack: a TCP connection arriving at
// the plain listener flows through feedRouterFromListener, the router's
// peek-and-dispatch, the wrapped HTTP server, and reaches the user
// handler. If the embedded netstack is delivering connections at all,
// this is the path they take. Failures localise to wiring bugs in the
// proxy, not the netstack.
func TestFeedRouterFromListener_DeliversConnectionToHandler(t *testing.T) {
logger := quietLogger()
hits := make(chan string, 1)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits <- r.Host
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("served"))
})
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "plain loopback bind must succeed")
t.Cleanup(func() { _ = plainLn.Close() })
router := nbtcp.NewRouter(logger, nil, &fakeAddr{addr: "127.0.0.1:0"}, nbtcp.WithPlainHTTP(plainLn.Addr()))
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
t.Cleanup(func() { _ = httpServer.Close() })
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-1"))
conn, err := net.DialTimeout("tcp", plainLn.Addr().String(), 2*time.Second)
require.NoError(t, err, "must connect to the plain listener")
t.Cleanup(func() { _ = conn.Close() })
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: app.example\r\nConnection: close\r\n\r\n"))
require.NoError(t, err, "request write must succeed")
resp, err := http.ReadResponse(bufioReader(conn), nil)
require.NoError(t, err, "must read response from server")
t.Cleanup(func() { _ = resp.Body.Close() })
assert.Equal(t, http.StatusOK, resp.StatusCode, "handler must be reached")
select {
case host := <-hits:
assert.Equal(t, "app.example", host, "handler must observe the request Host")
case <-time.After(2 * time.Second):
t.Fatal("handler was not invoked — connection did not flow through router → http server")
}
}
// TestFeedRouterFromListener_DispatchesTLSToTLSChannel verifies that a
// TLS ClientHello arriving on the plain listener is detected by the
// router peek and re-dispatched to the TLS channel — the cross-channel
// fallback the inbound stack relies on for HTTPS-on-:80 testing.
func TestFeedRouterFromListener_DispatchesTLSToTLSChannel(t *testing.T) {
logger := quietLogger()
hits := make(chan string, 1)
tlsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits <- r.Host
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("served-tls"))
})
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "plain loopback bind must succeed")
t.Cleanup(func() { _ = plainLn.Close() })
tlsLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "tls loopback bind must succeed")
t.Cleanup(func() { _ = tlsLn.Close() })
router := nbtcp.NewRouter(logger, nil, tlsLn.Addr(), nbtcp.WithPlainHTTP(plainLn.Addr()))
tlsConfig := selfSignedTLSConfig(t)
httpsServer := &http.Server{
Handler: tlsHandler,
TLSConfig: tlsConfig,
ReadHeaderTimeout: time.Second,
}
t.Cleanup(func() { _ = httpsServer.Close() })
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = httpsServer.ServeTLS(router.HTTPListener(), "", "") }()
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-tls"))
tlsConn, err := tls.Dial("tcp", plainLn.Addr().String(), &tls.Config{InsecureSkipVerify: true}) //nolint:gosec
require.NoError(t, err, "TLS dial against the plain listener must succeed (cross-channel)")
t.Cleanup(func() { _ = tlsConn.Close() })
req, err := http.NewRequest(http.MethodGet, "https://app.example/", nil)
require.NoError(t, err)
require.NoError(t, req.Write(tlsConn), "TLS request write must succeed")
resp, err := http.ReadResponse(bufioReader(tlsConn), req)
require.NoError(t, err, "must read TLS response")
t.Cleanup(func() { _ = resp.Body.Close() })
assert.Equal(t, http.StatusOK, resp.StatusCode, "TLS handler must be reached")
select {
case host := <-hits:
assert.Equal(t, "app.example", host, "TLS handler must observe the request Host")
case <-time.After(2 * time.Second):
t.Fatal("TLS handler was not invoked — peek/dispatch path is broken")
}
}
func selfSignedTLSConfig(t *testing.T) *tls.Config {
t.Helper()
cert, err := tls.X509KeyPair(testCertPEM, testKeyPEM)
require.NoError(t, err, "load static self-signed cert")
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec
}
// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for
// 127.0.0.1 — only used by tests that need a working TLS handshake.
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)
var testKeyPEM = []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)

View File

@@ -1,47 +0,0 @@
package auth
import (
"context"
"net/netip"
)
// PeerIdentity describes the locally-known facts about a peer reachable on
// the proxy's per-account WireGuard listener. Phase 3 fills PubKey, TunnelIP
// and FQDN from the embedded client's peerstore. UserID, Email and Groups
// stay zero in V1 — full identity still travels through ValidateTunnelPeer.
// Phase V2 will populate them once RemotePeerConfig carries user identity.
type PeerIdentity struct {
PubKey string
TunnelIP netip.Addr
FQDN string
// V2 fields (zero in V1).
UserID string
Email string
Groups []string
}
// TunnelLookupFunc resolves a tunnel IP to a peer identity using locally
// available peerstore data. ok=false means the IP is not in the calling
// account's roster.
type TunnelLookupFunc func(ip netip.Addr) (PeerIdentity, bool)
type tunnelLookupContextKey struct{}
// WithTunnelLookup attaches a per-account peerstore lookup function to
// the request context. The auth middleware calls this lookup before
// hitting management's ValidateTunnelPeer to short-circuit unknown IPs
// and to skip the RPC for already-cached identities.
func WithTunnelLookup(ctx context.Context, lookup TunnelLookupFunc) context.Context {
if lookup == nil {
return ctx
}
return context.WithValue(ctx, tunnelLookupContextKey{}, lookup)
}
// TunnelLookupFromContext returns the peerstore lookup attached to ctx,
// or nil when the request did not arrive on a per-account listener.
func TunnelLookupFromContext(ctx context.Context) TunnelLookupFunc {
v, _ := ctx.Value(tunnelLookupContextKey{}).(TunnelLookupFunc)
return v
}

View File

@@ -36,7 +36,6 @@ type authenticator interface {
// SessionValidator validates session tokens and checks user access permissions.
type SessionValidator interface {
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
ValidateTunnelPeer(ctx context.Context, in *proto.ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error)
}
// Scheme defines an authentication mechanism for a domain.
@@ -57,21 +56,12 @@ type DomainConfig struct {
AccountID types.AccountID
ServiceID types.ServiceID
IPRestrictions *restrict.Filter
// Private routes the domain through ValidateTunnelPeer; failure → 403.
Private bool
}
type validationResult struct {
UserID string
UserEmail string
Valid bool
DeniedReason string
Groups []string
// GroupNames carries the human-readable display names for Groups,
// ordered identically (positional pairing). May be shorter than
// Groups for tokens minted before names were embedded; the consumer
// falls back to ids for missing positions.
GroupNames []string
}
// Middleware applies per-domain authentication and IP restriction checks.
@@ -81,7 +71,6 @@ type Middleware struct {
logger *log.Logger
sessionValidator SessionValidator
geo restrict.GeoResolver
tunnelCache *tunnelValidationCache
}
// NewMiddleware creates a new authentication middleware. The sessionValidator is
@@ -95,7 +84,6 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo re
logger: logger,
sessionValidator: sessionValidator,
geo: geo,
tunnelCache: newTunnelValidationCache(),
}
}
@@ -123,15 +111,6 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
// Private services bypass operator schemes and gate on tunnel peer.
if config.Private {
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
return
}
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Domains with no authentication schemes pass through after IP checks.
if len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
@@ -150,54 +129,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
return
}
if mw.blockOIDCOnPlainHTTP(w, r, config) {
return
}
mw.authenticateWithSchemes(w, r, host, config)
})
}
// requestIsPlainHTTP reports whether the request arrived without TLS.
// Used to gate cookie-on-plain warnings and the OIDC plain-HTTP block.
func requestIsPlainHTTP(r *http.Request) bool {
return r.TLS == nil
}
// hasOIDCScheme reports whether any of the configured schemes requires
// TLS to round-trip safely with an external IdP.
func hasOIDCScheme(schemes []Scheme) bool {
for _, s := range schemes {
if s.Type() == auth.MethodOIDC {
return true
}
}
return false
}
// blockOIDCOnPlainHTTP fails fast when an OIDC-configured domain is hit
// over plain HTTP. Most IdPs reject http:// redirect URIs, so surfacing
// the misconfiguration here yields a clearer error than the IdP's
// "invalid redirect_uri" round-trip.
func (mw *Middleware) blockOIDCOnPlainHTTP(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
if !requestIsPlainHTTP(r) {
return false
}
if !hasOIDCScheme(config.Schemes) {
return false
}
mw.logger.WithFields(log.Fields{
"host": r.Host,
"remote": r.RemoteAddr,
}).Warn("OIDC scheme reached on plain HTTP path; rejecting with 400 — use port 443")
http.Error(w, "OIDC requires TLS — use port 443", http.StatusBadRequest)
return true
}
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
mw.domainsMux.RLock()
defer mw.domainsMux.RUnlock()
@@ -227,17 +162,7 @@ func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request
return false
}
var verdict restrict.Verdict
if types.IsOverlayOrigin(r.Context()) {
// Geo/CrowdSec checks don't apply over the WireGuard overlay:
// the source address is always inside the NetBird CGNAT range,
// which is never in a GeoIP database or a CrowdSec decision
// list. Enforcing them here would either no-op (best case) or
// fail-closed when the geo database is missing.
verdict = config.IPRestrictions.CheckCIDR(clientIP)
} else {
verdict = config.IPRestrictions.Check(clientIP, mw.geo)
}
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
if verdict == restrict.Allow {
return true
}
@@ -321,111 +246,18 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
if err != nil {
return false
}
userID, email, method, groups, groupNames, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
if err != nil {
return false
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetUserEmail(email)
cd.SetUserGroups(groups)
cd.SetUserGroupNames(groupNames)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return true
}
// forwardWithTunnelPeer is the OIDC fast-path for requests originating on the
// netbird mesh. When the source IP belongs to a private/CGNAT range the proxy
// asks management to resolve it to a peer/user and to gate by the service's
// distribution_groups. On success the proxy installs the freshly minted JWT
// as a session cookie, sets UserID + Method=oidc on the captured data, and
// forwards directly — operators see the same access-log shape as if the user
// had completed an OIDC redirect. Any failure (private-range mismatch,
// management unreachable, peer unknown, user not in group) returns false so
// the caller falls back to the existing OIDC scheme dispatch.
//
// Phase 3 adds a local-first short-circuit: when the request arrived on a
// per-account inbound listener the context carries a peerstore lookup
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
// roster the proxy denies fast without calling management. If the lookup
// confirms a known peer the RPC still runs for the user-identity tail
// (UserID + group access), but its result is cached for tunnelCacheTTL so
// repeat requests skip management entirely.
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
if mw.sessionValidator == nil {
return false
}
clientIP := mw.resolveClientIP(r)
if !clientIP.IsValid() {
return false
}
if !isTunnelSourceIP(clientIP) {
return false
}
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
if _, ok := lookup(clientIP); !ok {
mw.logger.WithFields(log.Fields{
"host": host,
"remote": clientIP,
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
return false
}
}
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{
accountID: config.AccountID,
tunnelIP: clientIP,
domain: host,
}, mw.validateTunnelPeer)
if err != nil {
mw.logger.WithError(err).Debug("ValidateTunnelPeer failed; falling back to OIDC")
return false
}
if !resp.GetValid() || resp.GetSessionToken() == "" {
return false
}
setSessionCookie(w, resp.GetSessionToken(), config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(resp.GetUserId())
cd.SetUserEmail(resp.GetUserEmail())
cd.SetUserGroups(resp.GetPeerGroupIds())
cd.SetUserGroupNames(resp.GetPeerGroupNames())
cd.SetAuthMethod(auth.MethodOIDC.String())
}
next.ServeHTTP(w, r)
return true
}
// validateTunnelPeer adapts the SessionValidator interface to the cache's
// validateTunnelPeerFn signature.
func (mw *Middleware) validateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
return mw.sessionValidator.ValidateTunnelPeer(ctx, req)
}
// cgnatPrefix covers RFC 6598 100.64.0.0/10, the CGNAT block NetBird
// allocates tunnel addresses from by default. IsPrivate() doesn't include
// it, so we check it explicitly.
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
// isTunnelSourceIP reports whether ip falls within an address range typical
// of NetBird tunnels: RFC1918 private space, IPv6 ULA, or CGNAT 100.64/10
// (NetBird's default range). Loopback and link-local are excluded — the
// fast-path is meant for peer-to-peer mesh traffic, not localhost.
func isTunnelSourceIP(ip netip.Addr) bool {
if !ip.IsValid() || ip.IsLoopback() || ip.IsLinkLocalUnicast() {
return false
}
if ip.IsPrivate() {
return true
}
return cgnatPrefix.Contains(ip)
}
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
// the request is forwarded directly (no redirect), which is important for API clients.
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
@@ -454,7 +286,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
if err != nil {
setHeaderCapturedData(r.Context(), "", "", nil, nil)
setHeaderCapturedData(r.Context(), "")
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
@@ -466,7 +298,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
}
if !result.Valid {
setHeaderCapturedData(r.Context(), result.UserID, result.UserEmail, result.Groups, result.GroupNames)
setHeaderCapturedData(r.Context(), result.UserID)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
@@ -474,9 +306,6 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
setSessionCookie(w, token, config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(auth.MethodHeader.String())
}
@@ -486,7 +315,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
if errors.Is(err, ErrHeaderAuthFailed) {
setHeaderCapturedData(r.Context(), "", "", nil, nil)
setHeaderCapturedData(r.Context(), "")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
@@ -498,7 +327,7 @@ func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Reque
return true
}
func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups, groupNames []string) {
func setHeaderCapturedData(ctx context.Context, userID string) {
cd := proxy.CapturedDataFromContext(ctx)
if cd == nil {
return
@@ -506,9 +335,6 @@ func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodHeader.String())
cd.SetUserID(userID)
cd.SetUserEmail(userEmail)
cd.SetUserGroups(groups)
cd.SetUserGroupNames(groupNames)
}
// authenticateWithSchemes tries each configured auth scheme in order.
@@ -579,9 +405,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
@@ -596,9 +419,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
@@ -634,9 +454,12 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
return false
}
// AddDomain registers authentication schemes for the given domain. With schemes a valid session public key is required.
// private=true forces ValidateTunnelPeer enforcement (403 on failure) regardless of the schemes list.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter, private bool) error {
// AddDomain registers authentication schemes for the given domain.
// If schemes are provided, a valid session public key is required to sign/verify
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
@@ -644,7 +467,6 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
Private: private,
}
return nil
}
@@ -666,7 +488,6 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
Private: private,
}
return nil
}
@@ -697,25 +518,18 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
}).Debug("Session validation denied")
return &validationResult{
UserID: resp.UserId,
UserEmail: resp.GetUserEmail(),
Valid: false,
DeniedReason: resp.DeniedReason,
}, nil
}
return &validationResult{
UserID: resp.UserId,
UserEmail: resp.GetUserEmail(),
Valid: true,
Groups: resp.GetPeerGroupIds(),
GroupNames: resp.GetPeerGroupNames(),
}, nil
return &validationResult{UserID: resp.UserId, Valid: true}, nil
}
userID, email, _, groups, groupNames, err := auth.ValidateSessionJWT(token, host, publicKey)
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return nil, err
}
return &validationResult{UserID: userID, UserEmail: email, Valid: true, Groups: groups, GroupNames: groupNames}, nil
return &validationResult{UserID: userID, Valid: true}, nil
}
// stripSessionTokenParam returns the request URI with the session_token query

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"errors"
"net/http"
@@ -24,7 +23,6 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -64,7 +62,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
require.NoError(t, err)
mw.domainsMux.RLock()
@@ -81,7 +79,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false)
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -95,7 +93,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false)
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key")
@@ -110,7 +108,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false)
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -123,7 +121,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock()
@@ -139,8 +137,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
mw.domainsMux.RLock()
config := mw.domains["example.com"]
@@ -156,7 +154,7 @@ func TestRemoveDomain(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
mw.RemoveDomain("example.com")
@@ -180,7 +178,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -197,7 +195,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -218,7 +216,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -239,9 +237,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
capturedData := proxy.NewCapturedData("")
@@ -264,48 +262,15 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
assert.Equal(t, "authenticated", rec.Body.String())
}
// TestProtect_SessionCookieGroupsPropagate verifies the cookie path lifts the
// JWT's groups claim into CapturedData so policy-aware middlewares can
// authorise without an extra management round-trip.
func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
groups := []string{"engineering", "sre"}
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, groups, nil, time.Hour)
require.NoError(t, err)
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd, "captured data must be present in request context")
assert.Equal(t, "test-user", cd.GetUserID())
assert.Equal(t, groups, cd.GetUserGroups(), "JWT groups claim must propagate to CapturedData")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "request with valid groups-bearing cookie must succeed")
assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes")
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
require.NoError(t, err)
var backendCalled bool
@@ -328,10 +293,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -355,10 +320,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
// Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -380,7 +345,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
scheme := &stubScheme{
@@ -392,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -445,7 +410,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -462,7 +427,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
require.NoError(t, err)
// First scheme (PIN) always fails, second scheme (password) succeeds.
@@ -481,7 +446,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -511,7 +476,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -535,7 +500,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false)
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
}
@@ -544,10 +509,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false)
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
require.Error(t, err)
// The original valid config should still be intact.
@@ -571,7 +536,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -598,7 +563,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -625,7 +590,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -713,7 +678,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false)
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -749,7 +714,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false)
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -790,7 +755,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false)
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -805,69 +770,6 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
}
// TestCheckIPRestrictions_OverlayOriginSkipsCountryRules covers the
// inbound (WG) listener path: requests stamped with WithOverlayOrigin
// must skip country lookups, even when no geo database is configured.
// Without this short-circuit the inbound flow would fail-closed for
// every overlay request whenever country rules are configured.
func TestCheckIPRestrictions_OverlayOriginSkipsCountryRules(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{
AllowedCIDRs: []string{"100.64.0.0/10"},
AllowedCountries: []string{"US"},
}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "100.64.5.6:5000"
req.Host = "example.com"
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code,
"overlay-origin requests must not be denied by country rules they would fail without geo data")
// Sanity check: the same filter without the overlay flag denies (no geo,
// country allowlist active → DenyGeoUnavailable).
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req2.RemoteAddr = "100.64.5.6:5000"
req2.Host = "example.com"
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusForbidden, rr2.Code,
"WAN-origin requests must still hit the full Check path and be denied without geo data")
}
// TestCheckIPRestrictions_OverlayOriginRespectsCIDR confirms CIDR
// rules still apply on the overlay path so operators retain a way to
// scope private services to specific peer subnets.
func TestCheckIPRestrictions_OverlayOriginRespectsCIDR(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"100.64.0.0/16"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "100.65.5.6:5000" // outside 100.64.0.0/16
req.Host = "example.com"
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code,
"CIDR rules must still apply on the overlay path")
}
func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
@@ -879,12 +781,11 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
return "", oidcURL, nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -908,12 +809,11 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -934,7 +834,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti
// returns a signed session token when the expected header value is provided.
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
t.Helper()
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
require.NoError(t, err)
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
@@ -952,7 +852,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
var backendCalled bool
capturedData := proxy.NewCapturedData("")
@@ -995,7 +895,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
// Also add a PIN scheme so we can verify fallthrough behavior.
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -1015,7 +915,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
return &proto.AuthenticateResponse{Success: false}, nil
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -1038,7 +938,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
return nil, errors.New("gRPC unavailable")
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -1055,7 +955,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -1106,7 +1006,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && accepted[ha.GetHeaderValue()] {
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
require.NoError(t, err)
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
@@ -1115,7 +1015,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
// Single Header scheme (as if one entry existed), but the mock checks both values.
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
var backendCalled bool
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -1159,71 +1059,3 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
assert.False(t, backendCalled, "unknown token should be rejected")
})
}
// TestProtect_OIDCOnPlainHTTP_BlockedWith400 verifies that when an OIDC
// scheme is configured and the request arrived without TLS, the middleware
// short-circuits with a 400 instead of dispatching to the IdP redirect.
func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodOIDC,
authFn: func(_ *http.Request) (string, string, error) {
return "", "https://idp.example.com/authorize", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code, "OIDC over plain HTTP should be rejected")
assert.Contains(t, rec.Body.String(), "OIDC requires TLS", "response body should explain the rejection")
}
// TestProtect_OIDCOverTLS_NotBlocked confirms the same configuration works
// over TLS — the block only fires on plain HTTP.
func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodOIDC,
authFn: func(_ *http.Request) (string, string, error) {
return "", "https://idp.example.com/authorize", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusFound, rec.Code, "OIDC over TLS should redirect to IdP")
}
// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked confirms that the OIDC
// block only fires when an OIDC scheme is configured. PIN-only domains
// pass through normally on plain HTTP.
func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
}

View File

@@ -1,171 +0,0 @@
package auth
import (
"context"
"net/netip"
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// tunnelCacheTTL caps how long a positive ValidateTunnelPeer result is
// reused before re-fetching from management. 5 minutes balances freshness
// against management load on busy mesh networks.
const tunnelCacheTTL = 300 * time.Second
// tunnelCachePerAccount caps the number of cached identities per account.
// Bounded eviction avoids memory growth in pathological cases (huge peer
// roster, brief request bursts) while staying generous for normal use.
const tunnelCachePerAccount = 1024
// tunnelCacheKey identifies a cached entry by tunnel IP and originating
// account. Domain is part of the value, not the key, because the
// management response is per (account, IP) — domain only gates whether a
// re-fetch is needed if the operator is accessing a different service.
type tunnelCacheKey struct {
accountID types.AccountID
tunnelIP netip.Addr
domain string
}
// tunnelCacheEntry stores a positive validation response with the time it
// was minted. Entries past tunnelCacheTTL are treated as misses.
type tunnelCacheEntry struct {
resp *proto.ValidateTunnelPeerResponse
cachedAt time.Time
}
// tunnelValidationCache memoizes ValidateTunnelPeer responses keyed by
// (accountID, tunnelIP, domain). Only successful, valid responses are
// cached — denials skip the cache so policy changes apply immediately.
// Single-flight de-duplicates concurrent fetches for the same key so a
// burst of cold requests collapses into a single RPC.
type tunnelValidationCache struct {
mu sync.Mutex
entries map[types.AccountID]*accountBucket
flight singleflight.Group
ttl time.Duration
maxSize int
now func() time.Time
}
// accountBucket holds the cached entries for a single account, with a
// FIFO eviction queue used when the bucket exceeds maxSize.
type accountBucket struct {
items map[tunnelCacheKey]tunnelCacheEntry
order []tunnelCacheKey
}
// newTunnelValidationCache constructs a cache with default TTL and bounds.
func newTunnelValidationCache() *tunnelValidationCache {
return &tunnelValidationCache{
entries: make(map[types.AccountID]*accountBucket),
ttl: tunnelCacheTTL,
maxSize: tunnelCachePerAccount,
now: time.Now,
}
}
// get returns a cached response for the key, or nil when missing or
// expired. Expired entries are evicted lazily on read.
func (c *tunnelValidationCache) get(key tunnelCacheKey) *proto.ValidateTunnelPeerResponse {
c.mu.Lock()
defer c.mu.Unlock()
bucket, ok := c.entries[key.accountID]
if !ok {
return nil
}
entry, ok := bucket.items[key]
if !ok {
return nil
}
if c.now().Sub(entry.cachedAt) > c.ttl {
delete(bucket.items, key)
bucket.order = removeKey(bucket.order, key)
return nil
}
return entry.resp
}
// put records a positive response under the key. Evicts the oldest entry
// in the account's bucket when the bound is exceeded.
func (c *tunnelValidationCache) put(key tunnelCacheKey, resp *proto.ValidateTunnelPeerResponse) {
c.mu.Lock()
defer c.mu.Unlock()
bucket, ok := c.entries[key.accountID]
if !ok {
bucket = &accountBucket{items: make(map[tunnelCacheKey]tunnelCacheEntry)}
c.entries[key.accountID] = bucket
}
if _, exists := bucket.items[key]; !exists {
bucket.order = append(bucket.order, key)
}
bucket.items[key] = tunnelCacheEntry{resp: resp, cachedAt: c.now()}
for len(bucket.order) > c.maxSize {
oldest := bucket.order[0]
bucket.order = bucket.order[1:]
delete(bucket.items, oldest)
}
}
// removeKey drops the first occurrence of needle from order. The cache
// uses small slices so a linear scan is cheaper than a map+slice combo.
func removeKey(order []tunnelCacheKey, needle tunnelCacheKey) []tunnelCacheKey {
for i, k := range order {
if k == needle {
return append(order[:i], order[i+1:]...)
}
}
return order
}
// flightKey turns a cache key into a single-flight string. AccountID and
// IP isolation by themselves are insufficient because different domains
// for the same peer/account may have different group access.
func flightKey(key tunnelCacheKey) string {
return string(key.accountID) + "|" + key.tunnelIP.String() + "|" + key.domain
}
// validateTunnelPeerFn is the RPC entry point the cache wraps. It matches
// the SessionValidator.ValidateTunnelPeer signature without exposing the
// gRPC option variadic, since callers don't need it on the cache hot path.
type validateTunnelPeerFn func(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error)
// fetch returns a cached response when present, otherwise calls validate
// under single-flight and caches the result. Denied responses pass
// through but are not cached so policy changes apply immediately.
func (c *tunnelValidationCache) fetch(ctx context.Context, key tunnelCacheKey, validate validateTunnelPeerFn) (*proto.ValidateTunnelPeerResponse, bool, error) {
if resp := c.get(key); resp != nil {
return resp, true, nil
}
flight := flightKey(key)
res, err, _ := c.flight.Do(flight, func() (any, error) {
if cached := c.get(key); cached != nil {
return cached, nil
}
resp, err := validate(ctx, &proto.ValidateTunnelPeerRequest{
TunnelIp: key.tunnelIP.String(),
Domain: key.domain,
})
if err != nil {
return nil, err
}
if resp.GetValid() && resp.GetSessionToken() != "" {
c.put(key, resp)
}
return resp, nil
})
if err != nil {
return nil, false, err
}
resp, _ := res.(*proto.ValidateTunnelPeerResponse)
return resp, false, nil
}

View File

@@ -1,171 +0,0 @@
package auth
import (
"context"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
func newTestKey(account types.AccountID, ip string, domain string) tunnelCacheKey {
return tunnelCacheKey{
accountID: account,
tunnelIP: netip.MustParseAddr(ip),
domain: domain,
}
}
func TestTunnelCache_HitSkipsRPC(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}, nil
}
resp, fromCache, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
require.NotNil(t, resp, "first fetch returns RPC response")
assert.False(t, fromCache, "first fetch must not be cached")
resp2, fromCache2, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
require.NotNil(t, resp2, "second fetch returns cached response")
assert.True(t, fromCache2, "second fetch must be served from cache")
assert.Equal(t, "user-1", resp2.GetUserId(), "cached response should preserve user identity")
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "validate should run exactly once with one cache hit")
}
func TestTunnelCache_ExpiredEntryRefetches(t *testing.T) {
cache := newTunnelValidationCache()
clock := time.Now()
cache.now = func() time.Time { return clock }
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
}
_, _, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "first fetch issues one RPC")
clock = clock.Add(tunnelCacheTTL + time.Second)
_, fromCache, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
assert.False(t, fromCache, "expired entry must miss the cache")
assert.Equal(t, int32(2), atomic.LoadInt32(&calls), "expired entry forces a re-fetch")
}
func TestTunnelCache_DeniedResponseNotCached(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: false, DeniedReason: "not_in_group"}, nil
}
for i := 0; i < 3; i++ {
_, _, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err, "fetch must not error on denied response")
}
assert.Equal(t, int32(3), atomic.LoadInt32(&calls), "denied responses bypass the cache so policy changes apply immediately")
}
func TestTunnelCache_ConcurrentColdHitsCoalesce(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
gate := make(chan struct{})
var calls int32
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
<-gate
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
}
const workers = 16
var wg sync.WaitGroup
wg.Add(workers)
results := make([]bool, workers)
for i := 0; i < workers; i++ {
go func(idx int) {
defer wg.Done()
resp, _, err := cache.fetch(context.Background(), key, validate)
results[idx] = err == nil && resp.GetValid()
}(i)
}
time.Sleep(20 * time.Millisecond)
close(gate)
wg.Wait()
for i, ok := range results {
assert.Truef(t, ok, "worker %d should observe a successful response", i)
}
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "single-flight must collapse concurrent cold fetches into one RPC")
}
func TestTunnelCache_PerAccountIsolation(t *testing.T) {
cache := newTunnelValidationCache()
keyA := newTestKey("acct-a", "100.64.0.10", "svc.example")
keyB := newTestKey("acct-b", "100.64.0.10", "svc.example")
var callsA, callsB int32
validateA := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&callsA, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-a", UserId: "user-a"}, nil
}
validateB := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&callsB, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-b", UserId: "user-b"}, nil
}
respA, _, err := cache.fetch(context.Background(), keyA, validateA)
require.NoError(t, err)
respB, _, err := cache.fetch(context.Background(), keyB, validateB)
require.NoError(t, err)
assert.Equal(t, "user-a", respA.GetUserId(), "account A response should belong to user-a")
assert.Equal(t, "user-b", respB.GetUserId(), "account B response must not be served from account A's cache")
assert.Equal(t, int32(1), atomic.LoadInt32(&callsA), "validateA called exactly once")
assert.Equal(t, int32(1), atomic.LoadInt32(&callsB), "validateB called exactly once")
}
func TestTunnelCache_BoundedSizeEvictsOldest(t *testing.T) {
cache := newTunnelValidationCache()
cache.maxSize = 2
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-" + req.GetTunnelIp()}, nil
}
keys := []tunnelCacheKey{
newTestKey("acct-1", "100.64.0.10", "svc"),
newTestKey("acct-1", "100.64.0.11", "svc"),
newTestKey("acct-1", "100.64.0.12", "svc"),
}
for _, k := range keys {
_, _, err := cache.fetch(context.Background(), k, validate)
require.NoError(t, err)
}
assert.Nil(t, cache.get(keys[0]), "oldest key should be evicted past maxSize")
assert.NotNil(t, cache.get(keys[1]), "second-newest must remain cached")
assert.NotNil(t, cache.get(keys[2]), "newest must remain cached")
}

View File

@@ -1,325 +0,0 @@
package auth
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/netip"
"sync/atomic"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/shared/management/proto"
)
// stubSessionValidator records ValidateTunnelPeer calls and returns the
// pre-canned response. Counts let tests assert RPC traffic.
type stubSessionValidator struct {
respFn func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse
respErr error
tunnelCalls atomic.Int32
}
func (s *stubSessionValidator) ValidateSession(_ context.Context, _ *proto.ValidateSessionRequest, _ ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
return &proto.ValidateSessionResponse{Valid: false}, nil
}
func (s *stubSessionValidator) ValidateTunnelPeer(_ context.Context, in *proto.ValidateTunnelPeerRequest, _ ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
s.tunnelCalls.Add(1)
if s.respErr != nil {
return nil, s.respErr
}
if s.respFn != nil {
return s.respFn(in), nil
}
return &proto.ValidateTunnelPeerResponse{Valid: false}, nil
}
func newTunnelMiddleware(t *testing.T, validator SessionValidator) *Middleware {
t.Helper()
mw := NewMiddleware(log.New(), validator, nil)
require.NoError(t, mw.AddDomain("svc.example", nil, "", 0, "acct-1", "svc-1", nil, false))
return mw
}
func newTunnelRequest(remoteAddr string) (*httptest.ResponseRecorder, *http.Request) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
r.Host = "svc.example"
r.RemoteAddr = remoteAddr
return w, r
}
// TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast verifies the
// short-circuit: a tunnel IP not in the account's roster never reaches
// management's ValidateTunnelPeer.
func TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast(t *testing.T) {
validator := &stubSessionValidator{}
mw := newTunnelMiddleware(t, validator)
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, false
})
w, r := newTunnelRequest("100.64.0.99:55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.False(t, handled, "unknown peer must fall through, not forward")
assert.False(t, called, "next handler must not run for unknown peer")
assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must be skipped on local-lookup miss")
}
// TestForwardWithTunnelPeer_GroupsPropagateToCapturedData verifies the proxy
// surfaces the calling peer's group memberships from ValidateTunnelPeerResponse
// onto CapturedData so policy-aware middlewares can authorise without an
// extra management round-trip.
func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
groups := []string{"engineering", "sre"}
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "tok",
UserId: "user-1",
PeerGroupIds: groups,
}
},
}
mw := newTunnelMiddleware(t, validator)
w, r := newTunnelRequest("100.64.0.10:55555")
cd := proxy.NewCapturedData("")
r = r.WithContext(proxy.WithCapturedData(r.Context(), cd))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
require.True(t, handled, "valid tunnel-peer response must forward")
require.True(t, called, "next handler must run")
assert.Equal(t, "user-1", cd.GetUserID(), "user id must propagate from tunnel-peer response")
assert.Equal(t, groups, cd.GetUserGroups(), "peer group IDs must propagate from tunnel-peer response")
}
// TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs verifies that a
// known tunnel IP still triggers ValidateTunnelPeer for the user-identity
// tail (UserID + group access). Phase 3 only short-circuits the deny path.
func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
},
}
mw := newTunnelMiddleware(t, validator)
knownIP := netip.MustParseAddr("100.64.0.10")
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
if ip == knownIP {
return PeerIdentity{PubKey: "pk", TunnelIP: ip, FQDN: "peer.netbird.cloud"}, true
}
return PeerIdentity{}, false
})
w, r := newTunnelRequest(knownIP.String() + ":55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.True(t, handled, "known peer with valid RPC response must forward")
assert.True(t, called, "next handler must run on success")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
}
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing
// behaviour stays intact on the host-level listener (no lookup attached).
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
},
}
mw := newTunnelMiddleware(t, validator)
w, r := newTunnelRequest("100.64.0.10:55555")
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.True(t, handled, "host-level path forwards on positive RPC result")
assert.True(t, called, "next handler runs on host-level success")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
}
// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC
// failure still falls through to the next scheme (no false positive).
func TestForwardWithTunnelPeer_RPCErrorFallsThrough(t *testing.T) {
validator := &stubSessionValidator{respErr: errors.New("management down")}
mw := newTunnelMiddleware(t, validator)
knownIP := netip.MustParseAddr("100.64.0.10")
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{TunnelIP: ip}, true
})
w, r := newTunnelRequest(knownIP.String() + ":55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
assert.False(t, handled, "RPC error must let the caller try other schemes")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC was attempted exactly once")
}
// TestForwardWithTunnelPeer_CacheReusesPositiveResponse confirms the
// (account, IP, domain) cache prevents repeated RPCs for the same peer.
func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
},
}
mw := newTunnelMiddleware(t, validator)
for i := 0; i < 4; i++ {
w, r := newTunnelRequest("100.64.0.10:55555")
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
require.True(t, handled, "iteration %d should forward", i)
}
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "subsequent forwards must hit the cache, not management")
}
// TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey ensures cache keys
// honour account scoping — same tunnel IP on different accounts must not
// collide.
func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user"}
},
}
mw := NewMiddleware(log.New(), validator, nil)
require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))
for _, host := range []string{"svc-a.example", "svc-b.example"} {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil)
r.Host = host
r.RemoteAddr = "100.64.0.10:55555"
config, _ := mw.getDomainConfig(host)
handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
require.True(t, handled, "host %s should forward", host)
}
assert.Equal(t, int32(2), validator.tunnelCalls.Load(), "cache must not collide across accounts even when tunnel IPs match")
}
// TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache
// guarantees that the deny-fast path leaves the cache untouched, so a
// subsequent request from the same IP after the peerstore catches up
// goes through the normal RPC flow.
func TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}
},
}
mw := newTunnelMiddleware(t, validator)
knownIP := netip.MustParseAddr("100.64.0.10")
known := false
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
if known && ip == knownIP {
return PeerIdentity{TunnelIP: ip}, true
}
return PeerIdentity{}, false
})
doRequest := func() bool {
w, r := newTunnelRequest(knownIP.String() + ":55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
config, _ := mw.getDomainConfig("svc.example")
return mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
}
require.False(t, doRequest(), "first request must short-circuit")
require.Equal(t, int32(0), validator.tunnelCalls.Load(), "short-circuit must not populate the cache")
known = true
require.True(t, doRequest(), "second request with peer in roster must forward via RPC")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC runs once after peerstore catches up")
}
func TestPrivateService_FailsClosedOnTunnelPeerFailure(t *testing.T) {
mw := NewMiddleware(log.New(), nil, nil)
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
called := false
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.10:55555"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.False(t, called)
}
func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "tok",
UserId: "user-1",
}
},
}
mw := NewMiddleware(log.New(), validator, nil)
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
called := false
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.10:55555"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.True(t, called)
}

View File

@@ -11,6 +11,7 @@ import (
"net/url"
"strings"
"time"
)
// StatusFilters contains filter options for status queries.
@@ -159,49 +160,6 @@ func (c *Client) printClients(data map[string]any) {
for _, item := range clients {
c.printClientRow(item)
}
c.printInboundListeners(clients)
}
func (c *Client) printInboundListeners(clients []any) {
type row struct {
accountID string
tunnelIP string
httpsPort int
httpPort int
}
var rows []row
for _, item := range clients {
client, ok := item.(map[string]any)
if !ok {
continue
}
inbound, ok := client["inbound_listener"].(map[string]any)
if !ok {
continue
}
tunnelIP, _ := inbound["tunnel_ip"].(string)
httpsPort, _ := inbound["https_port"].(float64)
httpPort, _ := inbound["http_port"].(float64)
accountID, _ := client["account_id"].(string)
rows = append(rows, row{
accountID: accountID,
tunnelIP: tunnelIP,
httpsPort: int(httpsPort),
httpPort: int(httpPort),
})
}
if len(rows) == 0 {
return
}
_, _ = fmt.Fprintln(c.out)
_, _ = fmt.Fprintln(c.out, "Inbound listeners (per-account):")
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7s %s\n", "ACCOUNT ID", "TUNNEL IP", "HTTPS", "HTTP")
_, _ = fmt.Fprintln(c.out, " "+strings.Repeat("-", 78))
for _, r := range rows {
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7d %d\n", r.accountID, r.tunnelIP, r.httpsPort, r.httpPort)
}
}
func (c *Client) printClientRow(item any) {
@@ -261,14 +219,7 @@ func (c *Client) ClientStatus(ctx context.Context, accountID string, filters Sta
}
func (c *Client) printClientStatus(data map[string]any) {
_, _ = fmt.Fprintf(c.out, "Account: %v\n", data["account_id"])
if inbound, ok := data["inbound_listener"].(map[string]any); ok {
tunnelIP, _ := inbound["tunnel_ip"].(string)
httpsPort, _ := inbound["https_port"].(float64)
httpPort, _ := inbound["http_port"].(float64)
_, _ = fmt.Fprintf(c.out, "Inbound listener: %s (https=%d, http=%d)\n", tunnelIP, int(httpsPort), int(httpPort))
}
_, _ = fmt.Fprintln(c.out)
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
if status, ok := data["status"].(string); ok {
_, _ = fmt.Fprint(c.out, status)
}

View File

@@ -61,23 +61,6 @@ type clientProvider interface {
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
}
// InboundListenerInfo describes a per-account inbound listener as
// surfaced through the debug HTTP handler. Mirrors the proto sub-message
// emitted with SendStatusUpdate so dashboards and CLI tooling see the
// same shape.
type InboundListenerInfo struct {
TunnelIP string `json:"tunnel_ip"`
HTTPSPort uint16 `json:"https_port"`
HTTPPort uint16 `json:"http_port"`
}
// InboundProvider exposes per-account inbound listener state. Optional;
// when nil the debug endpoint omits the inbound section entirely so the
// existing JSON shape stays additive.
type InboundProvider interface {
InboundListeners() map[types.AccountID]InboundListenerInfo
}
// healthChecker provides health probe state.
type healthChecker interface {
ReadinessProbe() bool
@@ -97,7 +80,6 @@ type Handler struct {
provider clientProvider
health healthChecker
certStatus certStatus
inbound InboundProvider
logger *log.Logger
startTime time.Time
templates *template.Template
@@ -126,13 +108,6 @@ func (h *Handler) SetCertStatus(cs certStatus) {
h.certStatus = cs
}
// SetInboundProvider wires per-account inbound listener observability.
// Pass nil (or skip the call) to keep the inbound section out of debug
// responses on proxies that don't run --private-inbound.
func (h *Handler) SetInboundProvider(p InboundProvider) {
h.inbound = p
}
func (h *Handler) loadTemplates() error {
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
if err != nil {
@@ -348,35 +323,23 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
sortedIDs := sortedAccountIDs(clients)
if wantJSON {
var inboundAll map[types.AccountID]InboundListenerInfo
if h.inbound != nil {
inboundAll = h.inbound.InboundListeners()
}
clientsJSON := make([]map[string]interface{}, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
row := map[string]interface{}{
clientsJSON = append(clientsJSON, map[string]interface{}{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
}
if inb, ok := inboundAll[id]; ok {
row["inbound_listener"] = inb
}
clientsJSON = append(clientsJSON, row)
})
}
resp := map[string]interface{}{
h.writeJSON(w, map[string]interface{}{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"clients": clientsJSON,
}
if len(inboundAll) > 0 {
resp["inbound_listener_count"] = len(inboundAll)
}
h.writeJSON(w, resp)
})
return
}
@@ -458,14 +421,10 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
})
if wantJSON {
resp := map[string]interface{}{
h.writeJSON(w, map[string]interface{}{
"account_id": accountID,
"status": overview.FullDetailSummary(),
}
if info, ok := h.inboundInfoFor(accountID); ok {
resp["inbound_listener"] = info
}
h.writeJSON(w, resp)
})
return
}
@@ -478,18 +437,6 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
h.renderTemplate(w, "clientDetail", data)
}
// inboundInfoFor returns the inbound listener info for an account, or
// ok=false when no inbound provider is wired or the account has no live
// listener.
func (h *Handler) inboundInfoFor(accountID types.AccountID) (InboundListenerInfo, bool) {
if h.inbound == nil {
return InboundListenerInfo{}, false
}
all := h.inbound.InboundListeners()
info, ok := all[accountID]
return info, ok
}
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
client, ok := h.provider.GetClient(accountID)
if !ok {

View File

@@ -25,11 +25,6 @@ type Metrics struct {
backendDuration metric.Int64Histogram
certificateIssueDuration metric.Int64Histogram
// Management sync metrics.
snapshotSyncDuration metric.Int64Histogram
snapshotBatchDuration metric.Int64Histogram
addPeerDuration metric.Int64Histogram
// L4 service-level metrics.
l4Services metric.Int64UpDownCounter
@@ -59,9 +54,6 @@ func New(ctx context.Context, meter metric.Meter) (*Metrics, error) {
if err := m.initHTTPMetrics(meter); err != nil {
return nil, err
}
if err := m.initSyncMetrics(meter); err != nil {
return nil, err
}
if err := m.initL4Metrics(meter); err != nil {
return nil, err
}
@@ -134,59 +126,6 @@ func (m *Metrics) initHTTPMetrics(meter metric.Meter) error {
return err
}
func (m *Metrics) initSyncMetrics(meter metric.Meter) error {
var err error
m.snapshotSyncDuration, err = meter.Int64Histogram(
"proxy.sync.snapshot.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration from management connect until the initial snapshot sync is complete"),
metric.WithExplicitBucketBoundaries(100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000, 300000),
)
if err != nil {
return err
}
m.snapshotBatchDuration, err = meter.Int64Histogram(
"proxy.sync.batch.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration to process a single mapping batch during initial snapshot sync"),
metric.WithExplicitBucketBoundaries(100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000, 300000),
)
if err != nil {
return err
}
m.addPeerDuration, err = meter.Int64Histogram(
"proxy.peer.add.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration to add a peer for an account (keygen + gRPC CreateProxyPeer + embed.New)"),
metric.WithExplicitBucketBoundaries(10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000),
)
return err
}
// RecordSnapshotSyncDuration records the total time from connect to sync-complete.
func (m *Metrics) RecordSnapshotSyncDuration(d time.Duration) {
m.snapshotSyncDuration.Record(m.ctx, d.Milliseconds())
}
// RecordSnapshotBatchDuration records the time to process one mapping batch during initial sync.
func (m *Metrics) RecordSnapshotBatchDuration(d time.Duration) {
m.snapshotBatchDuration.Record(m.ctx, d.Milliseconds())
}
// RecordAddPeerDuration records the time to create a new peer for an account.
func (m *Metrics) RecordAddPeerDuration(d time.Duration, err error) {
result := "success"
if err != nil {
result = "error"
}
m.addPeerDuration.Record(m.ctx, d.Milliseconds(), metric.WithAttributes(
attribute.String("result", result),
))
}
func (m *Metrics) initL4Metrics(meter metric.Meter) error {
var err error

View File

@@ -52,15 +52,8 @@ type CapturedData struct {
origin ResponseOrigin
clientIP netip.Addr
userID string
userEmail string
userGroups []string
// userGroupNames pairs positionally with userGroups; populated from
// the JWT's group_names claim or from ValidateSession/Tunnel
// responses. Slice may be shorter than userGroups for tokens minted
// before names were resolvable.
userGroupNames []string
authMethod string
metadata map[string]string
authMethod string
metadata map[string]string
}
// NewCapturedData creates a CapturedData with the given request ID.
@@ -145,81 +138,6 @@ func (c *CapturedData) GetUserID() string {
return c.userID
}
// SetUserEmail records the authenticated user's email address. Used by
// policy-aware middlewares to stamp identity onto upstream requests
// (e.g. x-litellm-end-user-id) without a management round-trip.
func (c *CapturedData) SetUserEmail(email string) {
c.mu.Lock()
defer c.mu.Unlock()
c.userEmail = email
}
// GetUserEmail returns the authenticated user's email address. Returns
// the empty string when the auth path didn't carry an email (e.g.
// non-OIDC schemes or legacy JWTs minted before the email claim).
func (c *CapturedData) GetUserEmail() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.userEmail
}
// SetUserGroups records the authenticated user's group memberships so
// downstream policy-aware middlewares can authorise the request without
// an additional management round-trip. The auth middleware populates this
// from ValidateSessionResponse / ValidateTunnelPeerResponse and from the
// session JWT's groups claim on cookie-bearing requests.
func (c *CapturedData) SetUserGroups(groups []string) {
c.mu.Lock()
defer c.mu.Unlock()
if len(groups) == 0 {
c.userGroups = nil
return
}
c.userGroups = append(c.userGroups[:0], groups...)
}
// GetUserGroups returns a copy of the authenticated user's group
// memberships.
func (c *CapturedData) GetUserGroups() []string {
c.mu.RLock()
defer c.mu.RUnlock()
if len(c.userGroups) == 0 {
return nil
}
out := make([]string, len(c.userGroups))
copy(out, c.userGroups)
return out
}
// SetUserGroupNames records the human-readable display names for the
// user's groups, ordered identically to UserGroups (positional
// pairing). Stamped onto upstream requests as X-NetBird-Groups so
// downstream services can read names rather than opaque ids.
func (c *CapturedData) SetUserGroupNames(names []string) {
c.mu.Lock()
defer c.mu.Unlock()
if len(names) == 0 {
c.userGroupNames = nil
return
}
c.userGroupNames = append(c.userGroupNames[:0], names...)
}
// GetUserGroupNames returns a copy of the authenticated user's group
// display names. Position i pairs with UserGroups[i]. May be shorter
// than UserGroups for tokens minted before names were resolvable; the
// consumer should fall back to ids for missing positions.
func (c *CapturedData) GetUserGroupNames() []string {
c.mu.RLock()
defer c.mu.RUnlock()
if len(c.userGroupNames) == 0 {
return nil
}
out := make([]string, len(c.userGroupNames))
copy(out, c.userGroupNames)
return out
}
// SetAuthMethod sets the authentication method used.
func (c *CapturedData) SetAuthMethod(method string) {
c.mu.Lock()

View File

@@ -86,9 +86,6 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if pt.RequestTimeout > 0 {
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
}
if pt.DirectUpstream {
ctx = roundtrip.WithDirectUpstream(ctx)
}
rewriteMatchedPath := result.matchedPath
if pt.PathRewrite == PathRewritePreserve {
@@ -145,8 +142,6 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Header.Set(k, v)
}
stampNetBirdIdentity(r)
clientIP := extractHostIP(r.In.RemoteAddr)
if isTrustedAddr(clientIP, p.trustedProxies) {
@@ -431,70 +426,3 @@ func opErrorContains(err error, substr string) bool {
}
return false
}
const (
// headerNetBirdUser carries the authenticated user's display identity
// (email when the peer is attached to a user, else peer name) onto
// upstream requests. Stripped from inbound requests before stamping
// so a client can't spoof identity by setting the header themselves.
headerNetBirdUser = "X-NetBird-User"
// headerNetBirdGroups carries the user's group display names as a
// comma-separated list. Falls back to group IDs at positions where a
// name wasn't available at session-mint time. Labels containing a
// comma or any non-printable byte are dropped at stamp time so the
// list is unambiguously splittable by consumers.
headerNetBirdGroups = "X-NetBird-Groups"
)
// isHeaderValueSafe reports whether v is a valid RFC 7230 field-value:
// VCHAR (0x21-0x7E), SP (0x20), or HTAB (0x09). Empty values are
// rejected; the caller decides whether to omit the header entirely.
func isHeaderValueSafe(v string) bool {
if v == "" {
return false
}
for i := 0; i < len(v); i++ {
c := v[i]
if c == '\t' || (c >= 0x20 && c <= 0x7E) {
continue
}
return false
}
return true
}
// stampNetBirdIdentity injects authenticated identity onto outbound
// requests as X-NetBird-User and X-NetBird-Groups. Always strips any
// client-sent values first (anti-spoof). Skips when the request didn't
// carry CapturedData (early-path errors, internal endpoints).
func stampNetBirdIdentity(r *httputil.ProxyRequest) {
r.Out.Header.Del(headerNetBirdUser)
r.Out.Header.Del(headerNetBirdGroups)
cd := CapturedDataFromContext(r.In.Context())
if cd == nil {
return
}
if email := cd.GetUserEmail(); isHeaderValueSafe(email) {
r.Out.Header.Set(headerNetBirdUser, email)
}
groupIDs := cd.GetUserGroups()
if len(groupIDs) == 0 {
return
}
groupNames := cd.GetUserGroupNames()
labels := make([]string, 0, len(groupIDs))
for i, id := range groupIDs {
label := id
if i < len(groupNames) && groupNames[i] != "" {
label = groupNames[i]
}
if !isHeaderValueSafe(label) || strings.ContainsRune(label, ',') {
continue
}
labels = append(labels, label)
}
if len(labels) > 0 {
r.Out.Header.Set(headerNetBirdGroups, strings.Join(labels, ","))
}
}

View File

@@ -1067,245 +1067,3 @@ func TestClassifyProxyError(t *testing.T) {
})
}
}
func TestStampNetBirdIdentity_NoCapturedData_StripsOnly(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
pr.In.Header.Set(headerNetBirdGroups, "admin")
pr.Out.Header = pr.In.Header.Clone()
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
"client-supplied X-NetBird-User must be stripped when no captured identity is present")
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
"client-supplied X-NetBird-Groups must be stripped when no captured identity is present")
}
func TestStampNetBirdIdentity_StampsFromCapturedData(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-eng", "grp-ops"})
cd.SetUserGroupNames([]string{"engineering", "operations"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "alice@netbird.io", pr.Out.Header.Get(headerNetBirdUser),
"captured email must overwrite any spoofed value")
assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups),
"group display names must be CSV-joined in positional order")
}
// TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty covers the
// tunnel-peer-without-user case (machine agents, unattached proxy peers).
// The proxy must still stamp the peer's groups so downstream services can
// authorise, but X-NetBird-User stays unset — only its inbound stripping
// must happen.
func TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserGroups([]string{"grp-machines"})
cd.SetUserGroupNames([]string{"machines"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
"X-NetBird-User must remain unset when CapturedData carries no email")
assert.Equal(t, "machines", pr.Out.Header.Get(headerNetBirdGroups),
"groups must still be stamped for peers without a user identity")
}
// TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty covers the symmetric
// case: identity-resolved user without resolved group memberships.
func TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("carol@netbird.io")
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "carol@netbird.io", pr.Out.Header.Get(headerNetBirdUser),
"email must be stamped even when no groups are captured")
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
"X-NetBird-Groups must remain unset when CapturedData carries no groups")
}
func TestStampNetBirdIdentity_FallsBackToGroupIDsWhenNameMissing(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
cd := NewCapturedData("req-1")
cd.SetUserEmail("bob@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"})
// "grp-b" gets an explicit empty-string display name (not just a
// shorter slice). Both gap shapes must fall back to the id.
cd.SetUserGroupNames([]string{"alpha", "", ""})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "alpha,grp-b,grp-c", pr.Out.Header.Get(headerNetBirdGroups),
"empty-string and out-of-range name slots must both fall back to the group id")
}
// TestStampNetBirdIdentity_DropsLabelsWithComma covers the
// comma-separator constraint: a group display name that itself contains
// a comma is dropped from the header (rather than corrupting the list),
// and the remaining labels are stamped.
func TestStampNetBirdIdentity_DropsLabelsWithComma(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"})
cd.SetUserGroupNames([]string{"engineering", "EU, EMEA", "operations"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups),
"group label with embedded comma must be dropped, remaining labels stamped")
}
// TestStampNetBirdIdentity_RejectsControlCharsInEmail covers the
// header-injection defence: an email value containing CR/LF/control
// chars is omitted entirely (not partially stamped) so the upstream
// request stays well-formed and no header injection is possible.
func TestStampNetBirdIdentity_RejectsControlCharsInEmail(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io\r\nX-Admin: yes")
cd.SetUserGroups([]string{"grp-a"})
cd.SetUserGroupNames([]string{"engineering"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
"email with CR/LF must be dropped, not partially stamped")
assert.Equal(t, "engineering", pr.Out.Header.Get(headerNetBirdGroups),
"groups remain stampable even when email is invalid")
}
// TestStampNetBirdIdentity_RejectsControlCharsInGroup covers the
// per-label defence: a group name with a control char is silently
// dropped, the rest are stamped.
func TestStampNetBirdIdentity_RejectsControlCharsInGroup(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b"})
cd.SetUserGroupNames([]string{"engineering\r\nsneaky", "operations"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "operations", pr.Out.Header.Get(headerNetBirdGroups),
"group label with control char must be dropped, valid ones kept")
}
// TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid covers the
// edge case where every group label is rejected: the header must not be
// set at all (rather than set to an empty string).
func TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b"})
cd.SetUserGroupNames([]string{"with,comma", "with\nbreak"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
_, present := pr.Out.Header[http.CanonicalHeaderKey(headerNetBirdGroups)]
assert.False(t, present,
"X-NetBird-Groups must not be set when every group label is rejected")
}
// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests
// that carry CapturedData with no identity fields populated (e.g. the
// auth middleware ran but the request didn't authenticate). Both
// headers must be cleared and neither stamped.
func TestStampNetBirdIdentity_CapturedDataPresentButEmpty(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
"X-NetBird-User must be stripped when CapturedData has no email")
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
"X-NetBird-Groups must be stripped when CapturedData has no groups")
}

View File

@@ -28,10 +28,6 @@ type PathTarget struct {
RequestTimeout time.Duration
PathRewrite PathRewriteMode
CustomHeaders map[string]string
// DirectUpstream selects the stdlib HTTP transport (host network stack)
// over the embedded NetBird WireGuard client when forwarding requests
// to this target. Default false → embedded client (existing behaviour).
DirectUpstream bool
}
// Mapping describes how a domain is routed by the HTTP reverse proxy.

View File

@@ -191,18 +191,6 @@ func (f *Filter) IsObserveOnly(v Verdict) bool {
return v.IsCrowdSec() && f.CrowdSecMode == CrowdSecObserve
}
// CheckCIDR runs only the CIDR allow/block evaluation. Use this when
// country and CrowdSec checks don't apply — e.g. requests arriving
// from the WireGuard overlay, whose source addresses live in the
// CGNAT range and have no meaningful geolocation or IP-reputation
// data.
func (f *Filter) CheckCIDR(addr netip.Addr) Verdict {
if f == nil {
return Allow
}
return f.checkCIDR(addr.Unmap())
}
// Check evaluates whether addr is permitted. CIDR rules are evaluated
// first because they are O(n) prefix comparisons. Country rules run
// only when CIDR checks pass and require a geo lookup. CrowdSec checks

View File

@@ -514,34 +514,6 @@ func TestFilter_CrowdSec_Observe_NilChecker(t *testing.T) {
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
}
func TestFilter_CheckCIDR_AllowsWithoutCountryOrCrowdSec(t *testing.T) {
cs := &mockCrowdSec{ready: true, decisions: map[string]*CrowdSecDecision{
"100.64.5.6": {Type: DecisionBan},
}}
f := ParseFilter(FilterConfig{
AllowedCIDRs: []string{"100.64.0.0/10"},
AllowedCountries: []string{"US"},
CrowdSec: cs,
CrowdSecMode: CrowdSecEnforce,
})
// CheckCIDR skips country + CrowdSec evaluation: an address inside
// the allowed CIDR passes even when it would be denied by CrowdSec
// or by the country allowlist (CGNAT addresses have no geo data).
assert.Equal(t, Allow, f.CheckCIDR(netip.MustParseAddr("100.64.5.6")),
"CheckCIDR must not run CrowdSec lookups on overlay traffic")
// CIDR denials still fire.
assert.Equal(t, DenyCIDR, f.CheckCIDR(netip.MustParseAddr("198.51.100.1")),
"CheckCIDR must still reject addresses outside the allow list")
}
func TestFilter_CheckCIDR_NilFilter(t *testing.T) {
var f *Filter
assert.Equal(t, Allow, f.CheckCIDR(netip.MustParseAddr("100.64.5.6")),
"CheckCIDR on a nil filter must allow")
}
func TestFilter_HasRestrictions_CrowdSec(t *testing.T) {
cs := &mockCrowdSec{ready: true}
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})

View File

@@ -1,112 +0,0 @@
package roundtrip
import (
"crypto/tls"
"errors"
"net"
"net/http"
"time"
log "github.com/sirupsen/logrus"
)
// MultiTransport dispatches each request to either the embedded NetBird
// http.RoundTripper or a stdlib http.Transport based on a per-request
// context flag set by the reverse-proxy rewrite step. When the flag is
// absent (the default for every existing target), requests follow the
// embedded NetBird path — current behaviour, preserved.
//
// The stdlib branch is used when a target was configured with
// direct_upstream=true. It dials via the host's network stack, which is
// what private (`netbird proxy`) deployments and centralised proxies
// fronting host-reachable upstreams (public APIs, LAN services,
// localhost sidecars) want.
//
// An embedded roundtripper is required. To run direct-only (no WG
// branch at all), construct the MultiTransport via NewDirectOnly.
type MultiTransport struct {
embedded http.RoundTripper
direct *http.Transport
insecure *http.Transport
}
// errNoEmbeddedTransport is returned when a request reaches the
// embedded branch on a MultiTransport that wasn't given one. Surfaces
// the misconfiguration to the caller instead of silently routing to
// the direct branch, which would bypass the WG tunnel.
var errNoEmbeddedTransport = errors.New("multitransport: embedded roundtripper not configured")
// NewMultiTransport wires both branches. embedded is the existing NetBird
// roundtripper and must not be nil — pass to NewDirectOnly for a
// MultiTransport that only ever uses the direct branch. The direct
// branches honour the same NB_PROXY_* tuning env vars as the embedded
// transport (see loadTransportConfig) plus a dial-timeout wrapper that
// respects types.WithDialTimeout.
func NewMultiTransport(embedded http.RoundTripper, logger *log.Logger) *MultiTransport {
if logger == nil {
logger = log.StandardLogger()
}
cfg := loadTransportConfig(logger)
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
direct := &http.Transport{
DialContext: dialWithTimeout(dialer.DialContext),
ForceAttemptHTTP2: true,
MaxIdleConns: cfg.maxIdleConns,
MaxIdleConnsPerHost: cfg.maxIdleConnsPerHost,
MaxConnsPerHost: cfg.maxConnsPerHost,
IdleConnTimeout: cfg.idleConnTimeout,
TLSHandshakeTimeout: cfg.tlsHandshakeTimeout,
ExpectContinueTimeout: cfg.expectContinueTimeout,
ResponseHeaderTimeout: cfg.responseHeaderTimeout,
WriteBufferSize: cfg.writeBufferSize,
ReadBufferSize: cfg.readBufferSize,
DisableCompression: cfg.disableCompression,
}
insecure := direct.Clone()
insecure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // matches the embedded NetBird transport's per-target opt-in
return &MultiTransport{
embedded: embedded,
direct: direct,
insecure: insecure,
}
}
// NewDirectOnly returns a MultiTransport with no embedded branch.
// Every request goes through the direct branch regardless of the
// per-request flag, so the embedded path can never be reached
// silently — wiring code that needs WG must use NewMultiTransport.
func NewDirectOnly(logger *log.Logger) *MultiTransport {
return NewMultiTransport(noEmbeddedRoundTripper{}, logger)
}
// noEmbeddedRoundTripper is the sentinel embedded transport for
// direct-only MultiTransports. RoundTrip is never called in practice
// because the direct branch matches every request, but if anything
// ever did reach this path it would fail loudly instead of falling
// back to direct.
type noEmbeddedRoundTripper struct{}
func (noEmbeddedRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, errNoEmbeddedTransport
}
// RoundTrip dispatches by reading the direct-upstream flag from the request
// context. When set, the request is forwarded via the stdlib transport,
// honouring the existing per-request skip-TLS-verify flag. Otherwise it
// goes through the embedded NetBird roundtripper.
func (m *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if DirectUpstreamFromContext(req.Context()) {
if skipTLSVerifyFromContext(req.Context()) {
return m.insecure.RoundTrip(req)
}
return m.direct.RoundTrip(req)
}
if m.embedded == nil {
return nil, errNoEmbeddedTransport
}
return m.embedded.RoundTrip(req)
}

View File

@@ -1,134 +0,0 @@
package roundtrip
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubRoundTripper records whether RoundTrip was called and returns a
// canned response so tests can assert the dispatch decision without
// running a real network.
type stubRoundTripper struct {
called bool
body string
}
func (s *stubRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
s.called = true
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(s.body)),
Header: http.Header{},
}, nil
}
func TestMultiTransport_DispatchesByContextFlag(t *testing.T) {
embedded := &stubRoundTripper{body: "embedded"}
mt := NewMultiTransport(embedded, nil)
t.Run("default routes to embedded", func(t *testing.T) {
embedded.called = false
req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
resp, err := mt.RoundTrip(req)
require.NoError(t, err, "embedded path must not error on stubbed transport")
require.NotNil(t, resp)
_ = resp.Body.Close()
assert.True(t, embedded.called, "request without WithDirectUpstream must hit the embedded transport")
})
t.Run("WithDirectUpstream skips embedded", func(t *testing.T) {
embedded.called = false
// Hit a server we control to verify the stdlib transport is used.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
defer srv.Close()
req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := mt.RoundTrip(req)
require.NoError(t, err, "direct path must dial via stdlib transport")
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
require.NoError(t, err)
assert.Equal(t, "direct", string(body), "stdlib transport must reach the test server")
assert.False(t, embedded.called, "WithDirectUpstream must bypass the embedded transport")
})
}
// TestMultiTransport_AppliesEnvOverridesToDirect verifies that the
// NB_PROXY_* env vars consumed by loadTransportConfig flow into the
// direct branches (previously they only applied to the embedded
// roundtripper, so direct-upstream traffic ignored operator tuning).
func TestMultiTransport_AppliesEnvOverridesToDirect(t *testing.T) {
t.Setenv(EnvMaxIdleConns, "42")
t.Setenv(EnvIdleConnTimeout, "11s")
t.Setenv(EnvTLSHandshakeTimeout, "7s")
mt := NewMultiTransport(&stubRoundTripper{body: "embedded"}, nil)
assert.Equal(t, 42, mt.direct.MaxIdleConns,
"NB_PROXY_MAX_IDLE_CONNS must propagate to the direct transport")
assert.Equal(t, 11*time.Second, mt.direct.IdleConnTimeout,
"NB_PROXY_IDLE_CONN_TIMEOUT must propagate to the direct transport")
assert.Equal(t, 7*time.Second, mt.direct.TLSHandshakeTimeout,
"NB_PROXY_TLS_HANDSHAKE_TIMEOUT must propagate to the direct transport")
assert.Equal(t, 42, mt.insecure.MaxIdleConns,
"env tuning must also apply to the insecure-skip-verify direct transport")
}
// TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested guards
// against the previous silent fallback: a MultiTransport constructed
// without an embedded transport must reject requests that don't
// explicitly opt into the direct branch, rather than routing them
// over the host stack and bypassing WireGuard.
func TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested(t *testing.T) {
mt := NewMultiTransport(nil, nil)
req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
resp, err := mt.RoundTrip(req)
if resp != nil {
_ = resp.Body.Close()
}
require.Error(t, err, "nil embedded must surface as an explicit error, not a silent direct dispatch")
assert.Nil(t, resp)
assert.ErrorIs(t, err, errNoEmbeddedTransport,
"the error must be the sentinel so callers can distinguish misconfiguration from network failures")
}
// TestMultiTransport_DirectOnlyServesDirectBranch verifies NewDirectOnly
// constructs a MultiTransport whose direct branch handles requests with
// the direct-upstream flag set, and surfaces the explicit sentinel
// when the embedded path is reached.
func TestMultiTransport_DirectOnlyServesDirectBranch(t *testing.T) {
mt := NewDirectOnly(nil)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "ok")
}))
defer srv.Close()
req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := mt.RoundTrip(req)
require.NoError(t, err, "direct-only must serve requests that opt into the direct branch")
_ = resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
wgReq := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
resp, err = mt.RoundTrip(wgReq)
if resp != nil {
_ = resp.Body.Close()
}
require.Error(t, err, "direct-only must refuse requests that didn't opt into the direct branch")
assert.Nil(t, resp)
assert.ErrorIs(t, err, errNoEmbeddedTransport)
}

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"sync"
"time"
@@ -77,11 +76,6 @@ type clientEntry struct {
services map[ServiceKey]serviceInfo
createdAt time.Time
started bool
// inbound is opaque per-account state owned by the NetBird parent's
// ReadyHandler. The roundtrip package never inspects this value; it
// only stores it so RemovePeer / StopAll can hand it back to the
// matching StopHandler. Nil when no inbound integration is active.
inbound any
// Per-backend in-flight limiting keyed by target host:port.
// TODO: clean up stale entries when backend targets change.
inflightMu sync.Mutex
@@ -89,19 +83,6 @@ type clientEntry struct {
maxInflight int
}
// IdentityForIP resolves a tunnel IP to the peer identity locally known by
// this account's embedded client. Returns (pubKey, fqdn) on success.
// ok=false means the IP is not in the account's roster — callers can use
// that as a fast deny without round-tripping management. The returned
// strings carry only what the embedded peerstore exposes; user identity
// (UserID / Email / Groups) still flows through ValidateTunnelPeer.
func (e *clientEntry) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if e == nil || e.client == nil || !ip.IsValid() {
return "", "", false
}
return e.client.IdentityForIP(ip)
}
// acquireInflight attempts to acquire an in-flight slot for the given backend.
// It returns a release function that must always be called, and true on success.
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
@@ -131,12 +112,6 @@ type ClientConfig struct {
MgmtAddr string
WGPort uint16
PreSharedKey string
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
// standalone proxy where the embedded client never accepts inbound;
// set to false on the private/embedded proxy so the engine creates
// the ACL manager and applies management's per-policy firewall rules
// (which is what gates per-account inbound listeners on the netstack).
BlockInbound bool
}
type statusNotifier interface {
@@ -162,19 +137,6 @@ type NetBird struct {
clients map[types.AccountID]*clientEntry
initLogOnce sync.Once
statusNotifier statusNotifier
// readyHandler runs after the embedded client for an account reports
// Ready. The opaque return value is stored on clientEntry and handed
// back to stopHandler when the entry is torn down. Nil disables the
// hook entirely (default for the standalone proxy).
readyHandler func(ctx context.Context, accountID types.AccountID, client *embed.Client) any
// stopHandler runs when an account's last service is removed (or the
// transport is shutting down). Receives whatever readyHandler returned.
stopHandler func(accountID types.AccountID, state any)
// OnAddPeer, when set, is called after AddPeer completes for a new account
// (i.e. when a new client was actually created, not when an existing one
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
OnAddPeer func(d time.Duration, err error)
}
// ClientDebugInfo contains debug information about a client.
@@ -222,11 +184,7 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
return nil
}
createStart := time.Now()
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
if n.OnAddPeer != nil {
n.OnAddPeer(time.Since(createStart), err)
}
if err != nil {
n.clientsMux.Unlock()
return err
@@ -306,15 +264,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: n.clientCfg.BlockInbound,
// The embedded proxy peer must never be a stepping stone into
// the proxy host's LAN: it only exists to reach NetBird mesh
// targets or, when direct_upstream is set, the host network
// stack via the MultiTransport's direct branch (which bypasses
// the engine routing entirely).
BlockLANAccess: true,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
BlockInbound: true,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
})
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
@@ -379,25 +331,8 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
}
}
readyHandler := n.readyHandler
n.clientsMux.Unlock()
if readyHandler != nil {
state := readyHandler(ctx, accountID, client)
n.clientsMux.Lock()
if e, ok := n.clients[accountID]; ok {
e.inbound = state
} else if state != nil && n.stopHandler != nil {
// Account was removed while readyHandler ran; tear down the
// resources it just brought up.
stop := n.stopHandler
n.clientsMux.Unlock()
stop(accountID, state)
n.clientsMux.Lock()
}
n.clientsMux.Unlock()
}
if n.statusNotifier == nil {
return
}
@@ -443,15 +378,11 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
var inbound any
var stopHandler func(types.AccountID, any)
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
inbound = entry.inbound
stopHandler = n.stopHandler
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
@@ -465,9 +396,6 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
if inbound != nil && stopHandler != nil {
stopHandler(accountID, inbound)
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
@@ -554,12 +482,8 @@ func (n *NetBird) StopAll(ctx context.Context) error {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
stopHandler := n.stopHandler
var merr *multierror.Error
for accountID, entry := range n.clients {
if entry.inbound != nil && stopHandler != nil {
stopHandler(accountID, entry.inbound)
}
entry.transport.CloseIdleConnections()
entry.insecureTransport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil {
@@ -612,19 +536,6 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
return entry.client, true
}
// IdentityForIP resolves a tunnel IP to a peer identity local to the given
// account. Delegates to clientEntry.IdentityForIP. Returns ok=false when
// the account has no client or the IP is not in its peerstore.
func (n *NetBird) IdentityForIP(accountID types.AccountID, ip netip.Addr) (pubKey, fqdn string, ok bool) {
n.clientsMux.RLock()
entry, exists := n.clients[accountID]
n.clientsMux.RUnlock()
if !exists {
return "", "", false
}
return entry.IdentityForIP(ip)
}
// ListClientsForDebug returns information about all clients for debug purposes.
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
n.clientsMux.RLock()
@@ -680,18 +591,6 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
}
}
// SetClientLifecycle registers callbacks that run when an embedded
// client becomes ready and when its entry is torn down. The opaque value
// returned by ready is stored on the entry and handed back to stop on
// cleanup. Must be called before AddPeer. A nil pair leaves the
// outbound-only behaviour intact.
func (n *NetBird) SetClientLifecycle(ready func(ctx context.Context, accountID types.AccountID, client *embed.Client) any, stop func(accountID types.AccountID, state any)) {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
n.readyHandler = ready
n.stopHandler = stop
}
// dialWithTimeout wraps a DialContext function so that any dial timeout
// stored in the context (via types.WithDialTimeout) is applied only to
// the connection establishment phase, not the full request lifetime.
@@ -734,22 +633,3 @@ func skipTLSVerifyFromContext(ctx context.Context) bool {
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
return v
}
// directUpstreamContextKey signals that the request should bypass the embedded
// NetBird WireGuard client and dial via the host's network stack instead.
// Set by the reverse-proxy rewrite step when the matched target carries
// PathTarget.DirectUpstream; consumed by MultiTransport.
type directUpstreamContextKey struct{}
// WithDirectUpstream marks the context so MultiTransport routes the request
// through its stdlib transport instead of the embedded NetBird roundtripper.
func WithDirectUpstream(ctx context.Context) context.Context {
return context.WithValue(ctx, directUpstreamContextKey{}, true)
}
// DirectUpstreamFromContext reports whether the context has been marked to
// bypass the embedded NetBird client.
func DirectUpstreamFromContext(ctx context.Context) bool {
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
return v
}

View File

@@ -3,7 +3,6 @@ package roundtrip
import (
"context"
"net/http"
"net/netip"
"sync"
"testing"
@@ -306,36 +305,6 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
assert.True(t, calls[0].connected)
}
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the
// public lookup short-circuits when no client has been registered for
// the queried account. The auth middleware uses ok=false as a fast deny.
func TestNetBird_IdentityForIP_UnknownAccountReturnsFalse(t *testing.T) {
nb := mockNetBird()
_, _, ok := nb.IdentityForIP("acct-missing", netip.MustParseAddr("100.64.0.10"))
assert.False(t, ok, "unknown account must yield ok=false")
}
// TestClientEntry_IdentityForIP_NilClientGuard ensures the receiver
// methods stay safe when called on partially-initialized state, which
// can happen briefly during AddPeer setup or test fixtures.
func TestClientEntry_IdentityForIP_NilClientGuard(t *testing.T) {
var e *clientEntry
_, _, ok := e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
assert.False(t, ok, "nil clientEntry must yield ok=false")
e = &clientEntry{}
_, _, ok = e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
assert.False(t, ok, "clientEntry with nil embed.Client must yield ok=false")
}
// TestClientEntry_IdentityForIP_InvalidIPReturnsFalse covers the input
// guard so callers don't have to repeat the check.
func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) {
e := &clientEntry{}
_, _, ok := e.IdentityForIP(netip.Addr{})
assert.False(t, ok, "invalid IP must yield ok=false")
}
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{

View File

@@ -36,7 +36,7 @@ func BenchmarkPeekClientHello_TLS(b *testing.B) {
for b.Loop() {
r := bytes.NewReader(hello)
conn := &readerConn{Reader: r}
sni, wrapped, _, err := PeekClientHello(conn)
sni, wrapped, err := PeekClientHello(conn)
if err != nil {
b.Fatal(err)
}
@@ -59,7 +59,7 @@ func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
for b.Loop() {
r := bytes.NewReader(httpReq)
conn := &readerConn{Reader: r}
_, wrapped, _, err := PeekClientHello(conn)
_, wrapped, err := PeekClientHello(conn)
if err != nil {
b.Fatal(err)
}

Some files were not shown because too many files have changed in this diff Show More