mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 16:19:56 +00:00
Compare commits
1 Commits
refactor/l
...
debug-logs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6240dcd96a |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,7 +12,6 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -33,4 +33,3 @@ infrastructure_files/setup-*.env
|
||||
vendor/
|
||||
/netbird
|
||||
client/netbird-electron/
|
||||
management/server/types/testdata/
|
||||
|
||||
@@ -8,14 +8,13 @@ There are many ways that you can contribute:
|
||||
- Sharing use cases in slack or Reddit
|
||||
- Bug fix or feature enhancement
|
||||
|
||||
If you haven't already, join our slack workspace [here](https://docs.netbird.io/slack-url), we would love to discuss topics that need community contribution and enhancements to existing features.
|
||||
If you haven't already, join our slack workspace [here](https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A), we would love to discuss topics that need community contribution and enhancements to existing features.
|
||||
|
||||
## Contents
|
||||
|
||||
- [Contributing to NetBird](#contributing-to-netbird)
|
||||
- [Contents](#contents)
|
||||
- [Code of conduct](#code-of-conduct)
|
||||
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||
- [Directory structure](#directory-structure)
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
@@ -34,14 +33,6 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
||||
By participating, you are expected to uphold this code. Please report
|
||||
unacceptable behavior to community@netbird.io.
|
||||
|
||||
## Discuss changes with the NetBird team first
|
||||
|
||||
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||
|
||||
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||
|
||||
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||
|
||||
## Directory structure
|
||||
|
||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||
|
||||
153
README.md
153
README.md
@@ -1,134 +1,147 @@
|
||||
|
||||
<div align="center">
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
|
||||
</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
</strong>
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<br>
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### Self-host NetBird (video)
|
||||
|
||||
### Self-Host NetBird (Video)
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|---|---|---|---|---|
|
||||
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
|
||||
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
|
||||
| ✓ Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
|
||||
| ✓ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
|
||||
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
|
||||
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
|
||||
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
|
||||
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
|
||||
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
|
||||
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
|
||||
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
|
||||
| | | | | ✓ OpenWRT |
|
||||
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
|
||||
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
|
||||
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
|
||||
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
|
||||
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
|
||||
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
|
||||
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
|
||||
- Check the NetBird [admin UI](https://app.netbird.io/).
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||
- Add more machines.
|
||||
|
||||
### Quickstart with self-hosted NetBird
|
||||
|
||||
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
|
||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
|
||||
- A **public domain** name pointing to the VM.
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
### A bit on NetBird internals
|
||||
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
|
||||
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
|
||||
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
|
||||
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
|
||||
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
|
||||
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||
|
||||

|
||||
|
||||
### Acknowledgements
|
||||
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
|
||||
### Legal
|
||||
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -301,11 +301,10 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
||||
for n, p := range fullStatus.Peers {
|
||||
pi := PeerInfo{
|
||||
IP: p.IP,
|
||||
IPv6: p.IPv6,
|
||||
FQDN: p.FQDN,
|
||||
ConnStatus: int(p.ConnStatus),
|
||||
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
int(p.ConnStatus),
|
||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
}
|
||||
@@ -337,84 +336,43 @@ func (c *Client) Networks() *NetworkArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
v6Merged := route.V6ExitMergeSet(routesMap)
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
for id, routes := range routesMap {
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, skip := v6Merged[id]; skip {
|
||||
continue
|
||||
|
||||
r := routes[0]
|
||||
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||
netStr := r.Network.String()
|
||||
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
|
||||
if network == nil {
|
||||
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
continue
|
||||
}
|
||||
networkArray.Add(*network)
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: routeSelector.IsSelected(id),
|
||||
Domains: domains,
|
||||
}
|
||||
networkArray.Add(network)
|
||||
}
|
||||
return networkArray
|
||||
}
|
||||
|
||||
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
|
||||
r := routes[0]
|
||||
netStr := r.Network.String()
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
routePeer, err := c.findBestRoutePeer(routes)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for route %s: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
network := &Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: selected,
|
||||
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
|
||||
}
|
||||
|
||||
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
|
||||
network.Network = "0.0.0.0/0, ::/0"
|
||||
}
|
||||
|
||||
return network
|
||||
}
|
||||
|
||||
// findBestRoutePeer returns the peer actively routing traffic for the given
|
||||
// HA route group. Falls back to the first connected peer, then the first peer.
|
||||
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
|
||||
netStr := routes[0].Network.String()
|
||||
|
||||
fullStatus := c.recorder.GetFullStatus()
|
||||
for _, p := range fullStatus.Peers {
|
||||
if _, ok := p.GetRoutes()[netStr]; ok {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
p, err := c.recorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if p.ConnStatus == peer.StatusConnected {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return c.recorder.GetPeer(routes[0].Peer)
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
|
||||
@@ -14,7 +14,6 @@ const (
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
type PeerInfo struct {
|
||||
IP string
|
||||
IPv6 string
|
||||
FQDN string
|
||||
ConnStatus int
|
||||
Routes PeerRoutes
|
||||
|
||||
@@ -307,24 +307,6 @@ func (p *Preferences) SetBlockInbound(block bool) {
|
||||
p.configInput.BlockInbound = &block
|
||||
}
|
||||
|
||||
// GetDisableIPv6 reads disable IPv6 setting from config file
|
||||
func (p *Preferences) GetDisableIPv6() (bool, error) {
|
||||
if p.configInput.DisableIPv6 != nil {
|
||||
return *p.configInput.DisableIPv6, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableIPv6, err
|
||||
}
|
||||
|
||||
// SetDisableIPv6 stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableIPv6(disable bool) {
|
||||
p.configInput.DisableIPv6 = &disable
|
||||
}
|
||||
|
||||
// Commit writes out the changes to the config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
|
||||
|
||||
@@ -18,12 +18,9 @@ func executeRouteToggle(id string, manager routemanager.Manager,
|
||||
netID := route.NetID(id)
|
||||
routes := []route.NetID{netID}
|
||||
|
||||
routesMap := manager.GetClientRoutesWithNetID()
|
||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||
log.Debugf("%s with id: %s", operationName, id)
|
||||
|
||||
log.Debugf("%s with ids: %v", operationName, routes)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
|
||||
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("error when %s: %s", operationName, err)
|
||||
return fmt.Errorf("error %s: %w", operationName, err)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -27,9 +26,8 @@ type Anonymizer struct {
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
|
||||
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
|
||||
// 198.51.100.0, 100::
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||
}
|
||||
|
||||
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||
@@ -50,7 +48,7 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
ip.IsInterfaceLocalMulticast() ||
|
||||
(ip.Is4() && ip.IsPrivate()) ||
|
||||
ip.IsPrivate() ||
|
||||
ip.IsUnspecified() ||
|
||||
ip.IsMulticast() ||
|
||||
isWellKnown(ip) ||
|
||||
@@ -98,11 +96,6 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||
// Handle CIDR notation (e.g. "2001:db8::/32")
|
||||
if prefix, err := netip.ParsePrefix(ip); err == nil {
|
||||
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return ip
|
||||
@@ -157,7 +150,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
|
||||
if u.Opaque != "" {
|
||||
host, port, err := net.SplitHostPort(u.Opaque)
|
||||
if err == nil {
|
||||
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
|
||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||
} else {
|
||||
anonymizedHost = a.AnonymizeDomain(u.Opaque)
|
||||
}
|
||||
@@ -165,7 +158,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
|
||||
} else if u.Host != "" {
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err == nil {
|
||||
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
|
||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||
} else {
|
||||
anonymizedHost = a.AnonymizeDomain(u.Host)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func TestAnonymizeIP(t *testing.T) {
|
||||
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
||||
startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
|
||||
startIPv6 := netip.MustParseAddr("100::")
|
||||
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
||||
|
||||
tests := []struct {
|
||||
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
|
||||
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
||||
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
||||
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Second Public IPv6", "a::b", "100::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Private IPv6", "fe80::1", "fe80::1"},
|
||||
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
||||
}
|
||||
@@ -274,27 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
|
||||
{
|
||||
name: "IPv6 Address",
|
||||
input: "Access attempted from 2001:db8::ff00:42",
|
||||
expect: "Access attempted from 2001:db8:ffff::",
|
||||
expect: "Access attempted from 100::",
|
||||
},
|
||||
{
|
||||
name: "IPv6 Address with Port",
|
||||
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
||||
expect: "Access attempted from [2001:db8:ffff::]:8080",
|
||||
expect: "Access attempted from [100::]:8080",
|
||||
},
|
||||
{
|
||||
name: "Both IPv4 and IPv6",
|
||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
|
||||
},
|
||||
{
|
||||
name: "STUN URI with IPv6",
|
||||
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
|
||||
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
|
||||
},
|
||||
{
|
||||
name: "HTTPS URI with IPv6",
|
||||
input: "Visit https://[2001:db8::ff00:42]:443/path",
|
||||
expect: "Visit https://[2001:db8:ffff::]:443/path",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
|
||||
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
|
||||
@@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port))
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return ""
|
||||
return "0.0.0.0"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: ":8080",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (dual-stack)",
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
ipv6Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
@@ -43,16 +42,15 @@ func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -103,14 +101,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ipv6Flag {
|
||||
ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6()
|
||||
if ipv6 != "" {
|
||||
cmd.Print(parseInterfaceIP(ipv6))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
|
||||
@@ -8,7 +8,6 @@ const (
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
blockInboundFlag = "block-inbound"
|
||||
disableIPv6Flag = "disable-ipv6"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -18,7 +17,6 @@ var (
|
||||
disableFirewall bool
|
||||
blockLANAccess bool
|
||||
blockInbound bool
|
||||
disableIPv6 bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -41,7 +39,4 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false,
|
||||
"Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.")
|
||||
}
|
||||
|
||||
@@ -435,10 +435,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
req.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(disableIPv6Flag).Changed {
|
||||
req.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
req.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
@@ -556,10 +552,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(disableIPv6Flag).Changed {
|
||||
ic.DisableIPv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
@@ -674,10 +666,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(disableIPv6Flag).Changed {
|
||||
loginRequest.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
@@ -80,8 +80,6 @@ type Options struct {
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
// DisableIPv6 disables IPv6 overlay addressing
|
||||
DisableIPv6 bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||
@@ -173,7 +171,6 @@ func New(opts Options) (*Client, error) {
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
DisableIPv6: &opts.DisableIPv6,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
@@ -336,7 +333,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
@@ -357,7 +354,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -40,7 +40,6 @@ type aclManager struct {
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
@@ -52,7 +51,6 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -87,11 +85,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||
if m.v6 && ipsetName != "" {
|
||||
ipsetName += "-v6"
|
||||
}
|
||||
proto := protoForFamily(protocol, m.v6)
|
||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
@@ -115,7 +109,6 @@ func (m *aclManager) AddPeerFiltering(
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
v6: m.v6,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
@@ -168,7 +161,6 @@ func (m *aclManager) AddPeerFiltering(
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
v6: m.v6,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
@@ -421,13 +413,8 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -435,22 +422,13 @@ func (m *aclManager) updateState() {
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
@@ -496,9 +474,6 @@ func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if m.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
|
||||
@@ -18,10 +18,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type resetter interface {
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
@@ -32,11 +28,6 @@ type Manager struct {
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
rawSupported bool
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
ipv6Client *iptables.IPTables
|
||||
aclMgr6 *aclManager
|
||||
router6 *router
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -67,43 +58,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
if err := m.createIPv6Components(wgIface, mtu); err != nil {
|
||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init ip6tables: %w", err)
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
|
||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.ipv6Client != nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
@@ -117,8 +74,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
if err := m.initChains(stateManager); err != nil {
|
||||
return err
|
||||
if err := m.router.init(stateManager); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclMgr.init(stateManager); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
@@ -141,41 +103,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initChains initializes router and ACL chains for both address families,
|
||||
// rolling back on failure.
|
||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
type initStep struct {
|
||||
name string
|
||||
init func(*statemanager.Manager) error
|
||||
mgr resetter
|
||||
}
|
||||
|
||||
steps := []initStep{
|
||||
{"router", m.router.init, m.router},
|
||||
{"acl manager", m.aclMgr.init, m.aclMgr},
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
steps = append(steps,
|
||||
initStep{"v6 router", m.router6.init, m.router6},
|
||||
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
||||
)
|
||||
}
|
||||
|
||||
var initialized []initStep
|
||||
for _, s := range steps {
|
||||
if err := s.init(stateManager); err != nil {
|
||||
for i := len(initialized) - 1; i >= 0; i-- {
|
||||
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s init: %w", s.name, err)
|
||||
}
|
||||
initialized = append(initialized, s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
@@ -191,13 +118,7 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@@ -211,48 +132,25 @@ func (m *Manager) AddRouteFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||
return m.aclMgr6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.v6
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@@ -268,65 +166,18 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Dynamic routes need NAT in both tables since resolved IPs can be
|
||||
// either v4 or v6. This covers both DomainSet (modern) and the legacy
|
||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.router.AddNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -340,15 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclMgr6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||
}
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -376,21 +218,24 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
var merr *multierror.Error
|
||||
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
|
||||
}
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
@@ -420,12 +265,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if rule.TranslatedAddress.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -434,9 +273,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -445,82 +281,39 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
if p.Addr().Is6() {
|
||||
v6Prefixes = append(v6Prefixes, p)
|
||||
} else {
|
||||
v4Prefixes = append(v4Prefixes, p)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -54,10 +54,8 @@ const (
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
@@ -88,7 +86,6 @@ type router struct {
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
@@ -100,7 +97,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
mtu: mtu,
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
@@ -190,11 +186,6 @@ func (r *router) AddRouteFiltering(
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) hasRule(id string) bool {
|
||||
_, ok := r.rules[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.ID()
|
||||
|
||||
@@ -401,13 +392,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
} else if ok {
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
@@ -447,12 +434,6 @@ func (r *router) createContainers() error {
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
// Fallback: clear chains that survived an unclean shutdown.
|
||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
@@ -559,12 +540,9 @@ func (r *router) addPostroutingRules() error {
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.v6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
@@ -749,13 +727,8 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -883,7 +856,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+fwdSuffix)
|
||||
@@ -910,7 +883,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
|
||||
rule = append(rule, destExp...)
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||
}
|
||||
@@ -927,12 +900,11 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
@@ -943,23 +915,27 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
||||
}
|
||||
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
name := r.ipsetName(set.HashedName())
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
@@ -967,12 +943,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
ruleInfo := ruleInfo{
|
||||
@@ -991,8 +967,8 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
@@ -1037,8 +1013,8 @@ func (r *router) ensureNATOutputChain() error {
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
@@ -1049,11 +1025,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
@@ -1066,8 +1042,8 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
@@ -1100,22 +1076,10 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
|
||||
// to avoid collisions since ipsets are global in the kernel.
|
||||
func (r *router) ipsetName(name string) string {
|
||||
if r.v6 {
|
||||
return name + "-v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (r *router) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if r.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
|
||||
@@ -9,7 +9,6 @@ type Rule struct {
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -34,12 +32,6 @@ type ShutdownState struct {
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
|
||||
// IPv6 counterparts
|
||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -70,28 +62,6 @@ func (s *ShutdownState) Cleanup() error {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
// Clean up v6 state even if the current run has no IPv6.
|
||||
// The previous run may have left ip6tables rules behind.
|
||||
if !ipt.hasIPv6() {
|
||||
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
|
||||
log.Warnf("failed to create v6 components for cleanup: %v", err)
|
||||
}
|
||||
}
|
||||
if ipt.hasIPv6() {
|
||||
if s.RouteRules6 != nil {
|
||||
ipt.router6.rules = s.RouteRules6
|
||||
}
|
||||
if s.RouteIPsetCounter6 != nil {
|
||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
}
|
||||
if s.ACLEntries6 != nil {
|
||||
ipt.aclMgr6.entries = s.ACLEntries6
|
||||
}
|
||||
if s.ACLIPsetStore6 != nil {
|
||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||
}
|
||||
}
|
||||
|
||||
if err := ipt.Close(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -12,10 +11,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// ErrIPv6NotInitialized is returned when an IPv6 address is passed to a firewall
|
||||
// method but the IPv6 firewall components were not initialized.
|
||||
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
||||
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
@@ -169,16 +164,18 @@ type Manager interface {
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
||||
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -12,10 +10,6 @@ type RouterPair struct {
|
||||
Destination Network
|
||||
Masquerade bool
|
||||
Inverse bool
|
||||
// Dynamic indicates the route is domain-based. NAT rules for dynamic
|
||||
// routes are duplicated to the v6 table so that resolved AAAA records
|
||||
// are masqueraded correctly.
|
||||
Dynamic bool
|
||||
}
|
||||
|
||||
func GetInversePair(pair RouterPair) RouterPair {
|
||||
@@ -26,17 +20,5 @@ func GetInversePair(pair RouterPair) RouterPair {
|
||||
Destination: pair.Source,
|
||||
Masquerade: pair.Masquerade,
|
||||
Inverse: true,
|
||||
Dynamic: pair.Dynamic,
|
||||
}
|
||||
}
|
||||
|
||||
// ToV6NatPair creates a v6 counterpart of a v4 NAT pair with `::/0` source
|
||||
// and, for prefix destinations, `::/0` destination.
|
||||
func ToV6NatPair(pair RouterPair) RouterPair {
|
||||
v6 := pair
|
||||
v6.Source = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if v6.Destination.IsPrefix() {
|
||||
v6.Destination = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
return v6
|
||||
}
|
||||
|
||||
@@ -33,12 +33,15 @@ const (
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
var (
|
||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
)
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
af addrFamily
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
@@ -64,7 +67,6 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
@@ -143,7 +145,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||
}
|
||||
@@ -252,11 +254,11 @@ func (m *AclManager) addIOFiltering(
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.protoOffset,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := m.af.protoNum(proto)
|
||||
protoData, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
}
|
||||
@@ -268,16 +270,19 @@ func (m *AclManager) addIOFiltering(
|
||||
})
|
||||
}
|
||||
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
rawIP := ip.To4()
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrOffset := uint32(12)
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.srcAddrOffset,
|
||||
Len: m.af.addrLen,
|
||||
Offset: addrOffset,
|
||||
Len: 4,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
@@ -582,7 +587,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
||||
|
||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
rawIP := ip.To4()
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
@@ -614,7 +619,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: m.af.setKeyType,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
@@ -702,12 +707,15 @@ func ifname(n string) []byte {
|
||||
return b
|
||||
}
|
||||
|
||||
|
||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||
if af.addrLen == 4 {
|
||||
return ip.To4()
|
||||
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return unix.IPPROTO_ICMP, nil
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var (
|
||||
// afIPv4 defines IPv4 header layout and nftables types.
|
||||
afIPv4 = addrFamily{
|
||||
protoOffset: 9,
|
||||
srcAddrOffset: 12,
|
||||
dstAddrOffset: 16,
|
||||
addrLen: net.IPv4len,
|
||||
totalBits: 8 * net.IPv4len,
|
||||
setKeyType: nftables.TypeIPAddr,
|
||||
tableFamily: nftables.TableFamilyIPv4,
|
||||
icmpProto: unix.IPPROTO_ICMP,
|
||||
}
|
||||
// afIPv6 defines IPv6 header layout and nftables types.
|
||||
afIPv6 = addrFamily{
|
||||
protoOffset: 6,
|
||||
srcAddrOffset: 8,
|
||||
dstAddrOffset: 24,
|
||||
addrLen: net.IPv6len,
|
||||
totalBits: 8 * net.IPv6len,
|
||||
setKeyType: nftables.TypeIP6Addr,
|
||||
tableFamily: nftables.TableFamilyIPv6,
|
||||
icmpProto: unix.IPPROTO_ICMPV6,
|
||||
}
|
||||
)
|
||||
|
||||
// addrFamily holds protocol-specific constants for nftables expression building.
|
||||
type addrFamily struct {
|
||||
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
|
||||
protoOffset uint32
|
||||
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
|
||||
srcAddrOffset uint32
|
||||
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
|
||||
dstAddrOffset uint32
|
||||
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
|
||||
addrLen uint32
|
||||
// totalBits is the address size in bits (32 for v4, 128 for v6)
|
||||
totalBits int
|
||||
// setKeyType is the nftables set data type for addresses
|
||||
setKeyType nftables.SetDatatype
|
||||
// tableFamily is the nftables table family
|
||||
tableFamily nftables.TableFamily
|
||||
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
|
||||
icmpProto uint8
|
||||
}
|
||||
|
||||
// familyForAddr returns the address family for the given IP.
|
||||
func familyForAddr(is4 bool) addrFamily {
|
||||
if is4 {
|
||||
return afIPv4
|
||||
}
|
||||
return afIPv6
|
||||
}
|
||||
|
||||
// protoNum converts a firewall protocol to the IP protocol number,
|
||||
// using the correct ICMP variant for the address family.
|
||||
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return af.icmpProto, nil
|
||||
case firewall.ProtocolALL:
|
||||
return 0, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestExternalChainMonitorRootIntegration verifies that adding a new chain
|
||||
// in an external (non-netbird) filter table triggers the reconciler.
|
||||
// Requires CAP_NET_ADMIN; skip otherwise.
|
||||
func TestExternalChainMonitorRootIntegration(t *testing.T) {
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("root required")
|
||||
}
|
||||
|
||||
calls := make(chan struct{}, 8)
|
||||
var count atomic.Int32
|
||||
rec := &countingReconciler{calls: calls, count: &count}
|
||||
|
||||
m := newExternalChainMonitor(rec)
|
||||
m.start()
|
||||
t.Cleanup(m.stop)
|
||||
|
||||
// Give the netlink subscription a moment to register.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
conn := &nftables.Conn{}
|
||||
table := conn.AddTable(&nftables.Table{
|
||||
Name: "nbmon_integration_test",
|
||||
Family: nftables.TableFamilyINet,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
cleanup := &nftables.Conn{}
|
||||
cleanup.DelTable(table)
|
||||
_ = cleanup.Flush()
|
||||
})
|
||||
|
||||
chain := conn.AddChain(&nftables.Chain{
|
||||
Name: "filter_INPUT",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
_ = chain
|
||||
require.NoError(t, conn.Flush(), "create external test chain")
|
||||
|
||||
select {
|
||||
case <-calls:
|
||||
// success
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("reconcile was not invoked after creating an external chain")
|
||||
}
|
||||
require.GreaterOrEqual(t, count.Load(), int32(1))
|
||||
}
|
||||
|
||||
type countingReconciler struct {
|
||||
calls chan struct{}
|
||||
count *atomic.Int32
|
||||
}
|
||||
|
||||
func (c *countingReconciler) reconcileExternalChains() error {
|
||||
c.count.Add(1)
|
||||
select {
|
||||
case c.calls <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,200 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
externalMonitorReconcileDelay = 500 * time.Millisecond
|
||||
externalMonitorInitInterval = 5 * time.Second
|
||||
externalMonitorMaxInterval = 5 * time.Minute
|
||||
externalMonitorRandomization = 0.5
|
||||
)
|
||||
|
||||
// externalChainReconciler re-applies passthrough accept rules to external
|
||||
// nftables chains. Implementations must be safe to call from the monitor
|
||||
// goroutine; the Manager locks its mutex internally.
|
||||
type externalChainReconciler interface {
|
||||
reconcileExternalChains() error
|
||||
}
|
||||
|
||||
// externalChainMonitor watches nftables netlink events and triggers a
|
||||
// reconcile when a new table or chain appears (e.g. after
|
||||
// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff
|
||||
// reconnect.
|
||||
type externalChainMonitor struct {
|
||||
reconciler externalChainReconciler
|
||||
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor {
|
||||
return &externalChainMonitor{reconciler: r}
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) start() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
done := make(chan struct{})
|
||||
m.done = done
|
||||
|
||||
go m.run(ctx, done)
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) stop() {
|
||||
m.mu.Lock()
|
||||
cancel := m.cancel
|
||||
done := m.done
|
||||
m.cancel = nil
|
||||
m.done = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
cancel()
|
||||
<-done
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
|
||||
defer close(done)
|
||||
|
||||
bo := &backoff.ExponentialBackOff{
|
||||
InitialInterval: externalMonitorInitInterval,
|
||||
RandomizationFactor: externalMonitorRandomization,
|
||||
Multiplier: backoff.DefaultMultiplier,
|
||||
MaxInterval: externalMonitorMaxInterval,
|
||||
MaxElapsedTime: 0,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
bo.Reset()
|
||||
|
||||
for ctx.Err() == nil {
|
||||
err := m.watch(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
delay := bo.NextBackOff()
|
||||
log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) watch(ctx context.Context) error {
|
||||
events, closeMon, err := m.subscribe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closeMon()
|
||||
|
||||
debounce := time.NewTimer(time.Hour)
|
||||
if !debounce.Stop() {
|
||||
<-debounce.C
|
||||
}
|
||||
defer debounce.Stop()
|
||||
|
||||
pending := false
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-debounce.C:
|
||||
pending = false
|
||||
m.reconcile()
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return errors.New("monitor channel closed")
|
||||
}
|
||||
if ev.Error != nil {
|
||||
return fmt.Errorf("monitor event: %w", ev.Error)
|
||||
}
|
||||
if !isRelevantMonitorEvent(ev) {
|
||||
continue
|
||||
}
|
||||
resetDebounce(debounce, pending)
|
||||
pending = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) {
|
||||
conn := &nftables.Conn{}
|
||||
mon := nftables.NewMonitor(
|
||||
nftables.WithMonitorAction(nftables.MonitorActionNew),
|
||||
nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables),
|
||||
)
|
||||
events, err := conn.AddMonitor(mon)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("add netlink monitor: %w", err)
|
||||
}
|
||||
return events, func() { _ = mon.Close() }, nil
|
||||
}
|
||||
|
||||
// resetDebounce reschedules a pending debounce timer without leaking a stale
|
||||
// fire on its channel. pending must reflect whether the timer is armed.
|
||||
func resetDebounce(t *time.Timer, pending bool) {
|
||||
if pending && !t.Stop() {
|
||||
select {
|
||||
case <-t.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
t.Reset(externalMonitorReconcileDelay)
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) reconcile() {
|
||||
if err := m.reconciler.reconcileExternalChains(); err != nil {
|
||||
log.Warnf("reconcile external chain rules: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// isRelevantMonitorEvent returns true for table/chain creation events on
|
||||
// families we care about. The reconciler filters to actual external filter
|
||||
// chains.
|
||||
func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool {
|
||||
switch ev.Type {
|
||||
case nftables.MonitorEventTypeNewChain:
|
||||
chain, ok := ev.Data.(*nftables.Chain)
|
||||
if !ok || chain == nil || chain.Table == nil {
|
||||
return false
|
||||
}
|
||||
return isMonitoredFamily(chain.Table.Family)
|
||||
case nftables.MonitorEventTypeNewTable:
|
||||
table, ok := ev.Data.(*nftables.Table)
|
||||
if !ok || table == nil {
|
||||
return false
|
||||
}
|
||||
return isMonitoredFamily(table.Family)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isMonitoredFamily(family nftables.TableFamily) bool {
|
||||
switch family {
|
||||
case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsMonitoredFamily(t *testing.T) {
|
||||
tests := []struct {
|
||||
family nftables.TableFamily
|
||||
want bool
|
||||
}{
|
||||
{nftables.TableFamilyIPv4, true},
|
||||
{nftables.TableFamilyIPv6, true},
|
||||
{nftables.TableFamilyINet, true},
|
||||
{nftables.TableFamilyARP, false},
|
||||
{nftables.TableFamilyBridge, false},
|
||||
{nftables.TableFamilyNetdev, false},
|
||||
{nftables.TableFamilyUnspecified, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRelevantMonitorEvent(t *testing.T) {
|
||||
inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet}
|
||||
ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
|
||||
arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ev *nftables.MonitorEvent
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "new chain in inet firewalld",
|
||||
ev: &nftables.MonitorEvent{
|
||||
Type: nftables.MonitorEventTypeNewChain,
|
||||
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "new chain in ip filter",
|
||||
ev: &nftables.MonitorEvent{
|
||||
Type: nftables.MonitorEventTypeNewChain,
|
||||
Data: &nftables.Chain{Name: "INPUT", Table: ipTable},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "new chain in unwatched arp family",
|
||||
ev: &nftables.MonitorEvent{
|
||||
Type: nftables.MonitorEventTypeNewChain,
|
||||
Data: &nftables.Chain{Name: "x", Table: arpTable},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "new table inet",
|
||||
ev: &nftables.MonitorEvent{
|
||||
Type: nftables.MonitorEventTypeNewTable,
|
||||
Data: inetTable,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "del chain (we only act on new)",
|
||||
ev: &nftables.MonitorEvent{
|
||||
Type: nftables.MonitorEventTypeDelChain,
|
||||
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "chain with nil table",
|
||||
ev: &nftables.MonitorEvent{
|
||||
Type: nftables.MonitorEventTypeNewChain,
|
||||
Data: &nftables.Chain{Name: "x"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil data",
|
||||
ev: &nftables.MonitorEvent{
|
||||
Type: nftables.MonitorEventTypeNewChain,
|
||||
Data: (*nftables.Chain)(nil),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fakeReconciler records reconcile invocations for debounce tests.
|
||||
type fakeReconciler struct {
|
||||
calls chan struct{}
|
||||
}
|
||||
|
||||
func (f *fakeReconciler) reconcileExternalChains() error {
|
||||
f.calls <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestExternalChainMonitorStopWithoutStart(t *testing.T) {
|
||||
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
|
||||
// Must not panic or block.
|
||||
m.stop()
|
||||
}
|
||||
|
||||
func TestExternalChainMonitorDoubleStart(t *testing.T) {
|
||||
// start() twice should be a no-op; stop() cleans up once.
|
||||
// We avoid exercising the netlink watch loop here because it needs root.
|
||||
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
|
||||
|
||||
// Replace run with a stub that just waits for cancel, so start() stays
|
||||
// deterministic without opening a netlink socket.
|
||||
origDone := make(chan struct{})
|
||||
m.done = origDone
|
||||
m.cancel = func() { close(origDone) }
|
||||
|
||||
// Second start should be a no-op (cancel already set).
|
||||
m.start()
|
||||
assert.NotNil(t, m.cancel)
|
||||
|
||||
m.stop()
|
||||
assert.Nil(t, m.cancel)
|
||||
assert.Nil(t, m.done)
|
||||
}
|
||||
@@ -11,11 +11,9 @@ import (
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -51,17 +49,10 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
router6 *router
|
||||
aclManager6 *AclManager
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
|
||||
extMonitor *externalChainMonitor
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
@@ -71,8 +62,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
@@ -85,137 +75,35 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
|
||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.extMonitor = newExternalChainMonitor(m)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
|
||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||
|
||||
var err error
|
||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.router6 != nil
|
||||
}
|
||||
|
||||
func (m *Manager) initIPv6() error {
|
||||
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init nftables firewall manager
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
if err := m.initFirewall(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.persistState(stateManager)
|
||||
|
||||
// Start after initFirewall has installed the baseline external-chain
|
||||
// accept rules. start() is idempotent across Init/Close/Init cycles.
|
||||
m.extMonitor.start()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileExternalChains re-applies passthrough accept rules to external
|
||||
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
|
||||
// tables or chains appear (e.g. after firewalld reloads).
|
||||
func (m *Manager) reconcileExternalChains() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
if m.router != nil {
|
||||
if err := m.router.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
||||
}
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (m *Manager) initFirewall() (err error) {
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create work table: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
m.rollbackInit()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := m.router.init(workTable); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.initIPv6(); err != nil {
|
||||
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
|
||||
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistState saves the current interface state for potential recreation on restart.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
@@ -226,29 +114,14 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
// persist early
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||
func (m *Manager) rollbackInit() {
|
||||
if err := m.router.Reset(); err != nil {
|
||||
log.Warnf("rollback router: %v", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 router: %v", err)
|
||||
}
|
||||
}
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
log.Warnf("cleanup tables: %v", err)
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Warnf("flush: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
@@ -267,14 +140,12 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@@ -288,11 +159,8 @@ func (m *Manager) AddRouteFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
@@ -303,66 +171,15 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||
return m.aclManager6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6Rule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||
// For static routes, the destination prefix determines the family. For dynamic
|
||||
// routes (DomainSet), the sources determine the family since management
|
||||
// duplicates dynamic rules per family.
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
|
||||
// router; the cached maps are normally authoritative, so the kernel is only
|
||||
// consulted when neither map knows about the rule.
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
id := rule.ID()
|
||||
r, err := m.routerForRuleID(id, (*router).hasRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
// routerForRuleID picks the router holding the rule with the given id, using
|
||||
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
|
||||
// from the kernel once and re-checks before falling back to the v4 router.
|
||||
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
|
||||
if has(m.router, id) {
|
||||
return m.router, nil
|
||||
}
|
||||
if m.hasIPv6() && has(m.router6, id) {
|
||||
return m.router6, nil
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return m.router, nil
|
||||
}
|
||||
if err := m.router.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
||||
}
|
||||
if err := m.router6.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
||||
}
|
||||
if has(m.router6, id) && !has(m.router, id) {
|
||||
return m.router6, nil
|
||||
}
|
||||
return m.router, nil
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
@@ -377,70 +194,19 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Dynamic routes need NAT in both tables since resolved IPs can be
|
||||
// either v4 or v6. This covers both DomainSet (modern) and the legacy
|
||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
||||
// On v6 failure we keep the v4 NAT rule rather than rolling back: half
|
||||
// connectivity is better than none, and RemoveNatRule is content-keyed
|
||||
// so the eventual cleanup still works.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.router.AddNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
//
|
||||
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
|
||||
// which doesn't override DROP rules in external tables (e.g. firewalld).
|
||||
// Should add passthrough rules to external chains (like the native mode router's
|
||||
// addExternalChainsRules does) for both the netbird table family and inet tables.
|
||||
// The netbird table itself is fine (routing chains already exist there), but
|
||||
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -448,11 +214,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||
}
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
@@ -466,47 +227,31 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Close closes the firewall manager
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.extMonitor.stop()
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
||||
}
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
|
||||
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
|
||||
return fmt.Errorf("delete state: %v", err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
@@ -555,12 +300,6 @@ func (m *Manager) Flush() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.Flush(); err != nil {
|
||||
return fmt.Errorf("flush v6 acl: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
@@ -573,12 +312,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if rule.TranslatedAddress.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -587,11 +320,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.DeleteDNATRule(rule)
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
@@ -599,82 +328,39 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
if p.Addr().Is6() {
|
||||
v6Prefixes = append(v6Prefixes, p)
|
||||
} else {
|
||||
v4Prefixes = append(v4Prefixes, p)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if localAddr.Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -848,11 +534,7 @@ func (m *Manager) refreshNoTrackChains() error {
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
return m.createWorkTableFamily(nftables.TableFamilyIPv4)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
@@ -864,7 +546,7 @@ func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
@@ -383,138 +383,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "failed to add NAT rule")
|
||||
|
||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
})
|
||||
require.NoError(t, err, "failed to add DNAT rule")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
|
||||
})
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
t.Skipf("%s not available on this system: %v", bin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Seed ip6 tables in the nft backend. Docker may not create them.
|
||||
seedIp6tables(t)
|
||||
|
||||
ifaceMockV6 := &iFaceMock{
|
||||
NameFunc: func() string { return "wt-test" },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil), "close manager")
|
||||
|
||||
stdout, stderr := runIp6tablesSave(t)
|
||||
verifyIp6tablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("fd00::2")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "add v6 peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "add v6 route filtering rule")
|
||||
|
||||
err = manager.AddNatRule(fw.RouterPair{
|
||||
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
|
||||
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
Masquerade: true,
|
||||
})
|
||||
require.NoError(t, err, "add v6 NAT rule")
|
||||
|
||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
})
|
||||
require.NoError(t, err, "add v6 DNAT rule")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
|
||||
})
|
||||
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
stdout, stderr = runIp6tablesSave(t)
|
||||
verifyIp6tablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func seedIp6tables(t *testing.T) {
|
||||
t.Helper()
|
||||
for _, tc := range []struct{ table, chain string }{
|
||||
{"filter", "FORWARD"},
|
||||
{"nat", "POSTROUTING"},
|
||||
{"mangle", "FORWARD"},
|
||||
} {
|
||||
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
|
||||
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
|
||||
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
|
||||
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
|
||||
}
|
||||
}
|
||||
|
||||
func runIp6tablesSave(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd := exec.Command("ip6tables-save")
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
require.NoError(t, cmd.Run(), "ip6tables-save failed")
|
||||
return stdout.String(), stderr.String()
|
||||
}
|
||||
|
||||
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
|
||||
t.Helper()
|
||||
for _, msg := range []string{
|
||||
"Table `nat' is incompatible",
|
||||
"Table `mangle' is incompatible",
|
||||
"Table `filter' is incompatible",
|
||||
} {
|
||||
require.NotContains(t, stdout, msg,
|
||||
"ip6tables-save stdout reports incompatibility: %s", stdout)
|
||||
require.NotContains(t, stderr, msg,
|
||||
"ip6tables-save stderr reports incompatibility: %s", stderr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
|
||||
@@ -50,10 +50,8 @@ const (
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
@@ -78,7 +76,6 @@ type router struct {
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
af addrFamily
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
@@ -91,7 +88,6 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
@@ -154,7 +150,7 @@ func (r *router) Reset() error {
|
||||
func (r *router) removeNatPreroutingRules() error {
|
||||
table := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
}
|
||||
chain := &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
@@ -187,7 +183,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
@@ -423,7 +419,7 @@ func (r *router) AddRouteFiltering(
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
protoNum, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -483,24 +479,7 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return r.getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *router) iptablesProto() iptables.Protocol {
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
return iptables.ProtocolIPv6
|
||||
}
|
||||
return iptables.ProtocolIPv4
|
||||
}
|
||||
|
||||
func (r *router) hasRule(id string) bool {
|
||||
_, ok := r.rules[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) hasDNATRule(id string) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
return getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
@@ -549,10 +528,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: r.af.setKeyType,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
@@ -585,17 +564,23 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
@@ -605,20 +590,10 @@ func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetEle
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
masked := prefix.Masked()
|
||||
if masked.Addr().Is4() {
|
||||
hostMask := ^uint32(0) >> masked.Bits()
|
||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||
|
||||
// IPv6: set host bits to all 1s
|
||||
b := masked.Addr().As16()
|
||||
bits := masked.Bits()
|
||||
for i := bits; i < 128; i++ {
|
||||
b[i/8] |= 1 << (7 - i%8)
|
||||
}
|
||||
return netip.AddrFrom16(b)
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
@@ -870,16 +845,9 @@ func (r *router) addPostroutingRules() {
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
if r.mtu <= overhead {
|
||||
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
|
||||
return nil
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
@@ -1086,22 +1054,17 @@ func (r *router) acceptFilterTableRules() error {
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available.
|
||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
// iptables is not available but the filter table exists
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
@@ -1172,122 +1135,83 @@ func (r *router) acceptExternalChainsRules() error {
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
for _, chain := range chains {
|
||||
r.applyExternalChainAccept(chain, intf)
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush external chain rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
r.insertForwardIifRule(chain, intf, existing)
|
||||
r.insertForwardOifEstablishedRule(chain, intf, existing)
|
||||
}
|
||||
|
||||
func (r *router) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleIif] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
iifRule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||
})
|
||||
}
|
||||
}
|
||||
r.conn.InsertRule(iifRule)
|
||||
|
||||
func (r *router) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleOif] {
|
||||
return
|
||||
}
|
||||
exprs := []expr.Any{
|
||||
oifExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
oifRule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs, getEstablishedExprs(2)...),
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
})
|
||||
}
|
||||
r.conn.InsertRule(oifRule)
|
||||
}
|
||||
|
||||
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
if existing[userDataAcceptInputRule] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
inputRule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
})
|
||||
}
|
||||
|
||||
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
|
||||
func (r *router) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list rules: %w", err)
|
||||
}
|
||||
present := map[string]bool{}
|
||||
for _, rule := range rules {
|
||||
if !isNetbirdAcceptRuleTag(rule.UserData) {
|
||||
continue
|
||||
}
|
||||
present[string(rule.UserData)] = true
|
||||
}
|
||||
return present, nil
|
||||
}
|
||||
|
||||
func isNetbirdAcceptRuleTag(userData []byte) bool {
|
||||
switch string(userData) {
|
||||
case userDataAcceptForwardRuleIif,
|
||||
userDataAcceptForwardRuleOif,
|
||||
userDataAcceptInputRule:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
r.conn.InsertRule(inputRule)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRules() error {
|
||||
@@ -1309,17 +1233,13 @@ func (r *router) removeFilterTableRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
@@ -1386,7 +1306,7 @@ func (r *router) removeExternalChainsRules() error {
|
||||
func (r *router) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
@@ -1417,8 +1337,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1559,7 +1479,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||
protoNum, err := protoToInt(rule.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -1622,7 +1542,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: regProtoMin,
|
||||
RegProtoMax: regProtoMax,
|
||||
@@ -1715,15 +1635,14 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
|
||||
},
|
||||
)
|
||||
|
||||
natTable := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: natTable,
|
||||
Table: &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
},
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: natTable,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
@@ -1754,8 +1673,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: r.af.dstAddrOffset,
|
||||
Len: r.af.addrLen,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
@@ -1833,7 +1752,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
@@ -1848,14 +1767,14 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -1882,15 +1801,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
@@ -1899,11 +1814,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
@@ -1928,12 +1843,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
@@ -1979,8 +1894,8 @@ func (r *router) ensureNATOutputChain() error {
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
@@ -1990,7 +1905,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -2011,15 +1926,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
@@ -2028,11 +1939,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
@@ -2056,12 +1967,12 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
@@ -2100,44 +2011,45 @@ func (r *router) applyNetwork(
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return r.applyPrefix(network.Prefix, isSource), nil
|
||||
return applyPrefix(network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||
func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = r.af.srcAddrOffset
|
||||
offset = 12
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
// unspecified address (/0) doesn't need extra expressions
|
||||
// 0.0.0.0/0 doesn't need extra expressions
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, r.af.totalBits)
|
||||
xor := make([]byte, r.af.addrLen)
|
||||
mask := net.CIDRMask(ones, 32)
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: r.af.addrLen,
|
||||
Len: 4,
|
||||
},
|
||||
// netmask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: r.af.addrLen,
|
||||
Len: 4,
|
||||
Mask: mask,
|
||||
Xor: xor,
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
@@ -2220,12 +2132,13 @@ func getCtNewExprs() []expr.Any {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = r.af.srcAddrOffset
|
||||
offset = 12
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
@@ -2233,7 +2146,7 @@ func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool)
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: r.af.addrLen,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
|
||||
@@ -90,9 +90,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
testRouter := &router{af: afIPv4}
|
||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -509,136 +508,6 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTableIPv6()
|
||||
require.NoError(t, err, "Failed to create v6 work table")
|
||||
defer deleteWorkTableIPv6()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Single IPv6",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPv6 Subnets",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::/48"),
|
||||
netip.MustParsePrefix("fe80::/10"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping IPv6",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/48"),
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("fd00::1/128"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/48"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Mixed prefix lengths",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::1/128"),
|
||||
netip.MustParsePrefix("fd00:abcd::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||
require.NoError(t, err, "Failed to create IPv6 set")
|
||||
require.NotNil(t, set)
|
||||
|
||||
assert.Equal(t, setName, set.Name)
|
||||
assert.True(t, set.Interval)
|
||||
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
|
||||
|
||||
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||
require.NoError(t, err, "Failed to fetch created set")
|
||||
|
||||
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||
require.NoError(t, err, "Failed to get set elements")
|
||||
|
||||
uniquePrefixes := make(map[string]bool)
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd && len(elem.Key) == 16 {
|
||||
ip := netip.AddrFrom16([16]byte(elem.Key))
|
||||
uniquePrefixes[ip.String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
expectedCount := len(tt.expected)
|
||||
if expectedCount == 0 {
|
||||
expectedCount = len(tt.sources)
|
||||
}
|
||||
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
|
||||
|
||||
r.conn.DelSet(set)
|
||||
require.NoError(t, r.conn.Flush())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createWorkTableIPv6() (*nftables.Table, error) {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
|
||||
err = sConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func deleteWorkTableIPv6() {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
_ = sConn.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||
t.Helper()
|
||||
|
||||
@@ -758,7 +627,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
|
||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
var metaFound, cmpFound bool
|
||||
expectedProto, _ := afIPv4.protoNum(proto)
|
||||
expectedProto, _ := protoToInt(proto)
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Meta:
|
||||
@@ -985,55 +854,3 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||
}
|
||||
|
||||
func TestCalculateLastIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
prefix string
|
||||
want string
|
||||
}{
|
||||
{"10.0.0.0/24", "10.0.0.255"},
|
||||
{"10.0.0.0/32", "10.0.0.0"},
|
||||
{"0.0.0.0/0", "255.255.255.255"},
|
||||
{"192.168.1.0/28", "192.168.1.15"},
|
||||
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
|
||||
{"fd00::/128", "fd00::"},
|
||||
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
|
||||
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.prefix, func(t *testing.T) {
|
||||
prefix := netip.MustParsePrefix(tt.prefix)
|
||||
got := calculateLastIP(prefix)
|
||||
assert.Equal(t, tt.want, got.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||
r := &router{af: afIPv6}
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
|
||||
// Each prefix produces 2 elements (start + end)
|
||||
require.Len(t, elements, 4)
|
||||
|
||||
// fd00::/64 start
|
||||
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
|
||||
assert.False(t, elements[0].IntervalEnd)
|
||||
|
||||
// fd00::/64 end (fd00:0:0:1::, one past the last)
|
||||
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
|
||||
assert.True(t, elements[1].IntervalEnd)
|
||||
|
||||
// 2001:db8::1/128 start
|
||||
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
|
||||
assert.False(t, elements[2].IntervalEnd)
|
||||
|
||||
// 2001:db8::1/128 end (2001:db8::2)
|
||||
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
|
||||
assert.True(t, elements[3].IntervalEnd)
|
||||
}
|
||||
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -31,20 +29,15 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
||||
}
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
||||
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
||||
}
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
@@ -53,33 +46,17 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
if err := manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
||||
if err := manageFirewallRule(firewallRuleName+"-v6",
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+v6.String(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
)
|
||||
}
|
||||
|
||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "4")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 4, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 4,
|
||||
"TCP table must not exceed the configured cap")
|
||||
require.Greater(t, len(tracker.connections), 0,
|
||||
"some entries must remain after eviction")
|
||||
|
||||
// The most recently admitted flow must be present: eviction must make
|
||||
// room for new entries, not silently drop them.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
|
||||
"newest TCP flow must be admitted after eviction")
|
||||
// A pre-cap flow must have been evicted to fit the last one.
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
|
||||
"oldest TCP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "3")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
// Fill to cap with 3 live connections.
|
||||
for i := 0; i < 3; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.Len(t, tracker.connections, 3)
|
||||
|
||||
// Tombstone one by sending RST through IsValidInbound.
|
||||
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
|
||||
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
|
||||
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
|
||||
|
||||
// Another live connection forces eviction. The tombstone must go first.
|
||||
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
|
||||
|
||||
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
|
||||
require.False(t, tombstonedStillPresent,
|
||||
"tombstoned entry should be evicted before live entries")
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
|
||||
// Both live pre-cap entries must survive: eviction must prefer the
|
||||
// tombstone, not just satisfy the size bound by dropping any entry.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
}
|
||||
|
||||
func TestUDPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvUDPMaxEntries, "5")
|
||||
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 5, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 5)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
|
||||
"newest UDP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
|
||||
"oldest UDP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestICMPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvICMPMaxEntries, "3")
|
||||
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 3, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
|
||||
for i := 0; i < 8; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
|
||||
"newest ICMP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
|
||||
"oldest ICMP flow should have been evicted")
|
||||
}
|
||||
@@ -1,63 +1,16 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// evictSampleSize bounds how many map entries we scan per eviction call.
|
||||
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
|
||||
// heuristic is good enough for a conntrack table that only overflows under
|
||||
// abuse.
|
||||
const evictSampleSize = 8
|
||||
|
||||
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
|
||||
// def on empty or invalid; logs a warning on invalid.
|
||||
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
d, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
}
|
||||
if d <= 0 {
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
|
||||
// invalid, or non-positive. Logs a warning on invalid input.
|
||||
func envInt(logger *nblog.Logger, name string, def int) int {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
switch {
|
||||
case err != nil:
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
case n <= 0:
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
FlowId uuid.UUID
|
||||
@@ -111,7 +64,5 @@ type ConnKey struct {
|
||||
}
|
||||
|
||||
func (c ConnKey) String() string {
|
||||
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
|
||||
" → " +
|
||||
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
|
||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
}
|
||||
|
||||
@@ -13,54 +13,6 @@ import (
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestConnKey_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key ConnKey
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "IPv4",
|
||||
key: ConnKey{
|
||||
SrcIP: netip.MustParseAddr("192.168.1.1"),
|
||||
DstIP: netip.MustParseAddr("10.0.0.1"),
|
||||
SrcPort: 12345,
|
||||
DstPort: 80,
|
||||
},
|
||||
expect: "192.168.1.1:12345 → 10.0.0.1:80",
|
||||
},
|
||||
{
|
||||
name: "IPv6",
|
||||
key: ConnKey{
|
||||
SrcIP: netip.MustParseAddr("2001:db8::1"),
|
||||
DstIP: netip.MustParseAddr("2001:db8::2"),
|
||||
SrcPort: 54321,
|
||||
DstPort: 443,
|
||||
},
|
||||
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
|
||||
},
|
||||
{
|
||||
name: "IPv4-mapped IPv6 unmaps",
|
||||
key: ConnKey{
|
||||
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
|
||||
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
|
||||
SrcPort: 1000,
|
||||
DstPort: 2000,
|
||||
},
|
||||
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.key.String()
|
||||
if got != tc.expect {
|
||||
t.Errorf("got %q, want %q", got, tc.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on desktop/server platforms. These mirror
|
||||
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
|
||||
const (
|
||||
DefaultMaxTCPEntries = 65536
|
||||
DefaultMaxUDPEntries = 16384
|
||||
DefaultMaxICMPEntries = 2048
|
||||
)
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on mobile platforms. iOS network extensions
|
||||
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
|
||||
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
|
||||
// is ~200 B plus map overhead).
|
||||
const (
|
||||
DefaultMaxTCPEntries = 4096
|
||||
DefaultMaxUDPEntries = 2048
|
||||
DefaultMaxICMPEntries = 512
|
||||
)
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -22,14 +21,9 @@ const (
|
||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||
ICMPCleanupInterval = 15 * time.Second
|
||||
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
|
||||
// IPv4: 20-byte header + 8-byte transport = 28 bytes.
|
||||
// IPv6: 40-byte header + 8-byte transport = 48 bytes.
|
||||
MaxICMPPayloadLength = 48
|
||||
// minICMPPayloadIPv4 is the minimum embedded packet length for IPv4 ICMP errors.
|
||||
minICMPPayloadIPv4 = 28
|
||||
// minICMPPayloadIPv6 is the minimum embedded packet length for IPv6 ICMP errors.
|
||||
minICMPPayloadIPv6 = 48
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||
MaxICMPPayloadLength = 28
|
||||
)
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
@@ -50,9 +44,6 @@ type ICMPConnTrack struct {
|
||||
ICMPCode uint8
|
||||
}
|
||||
|
||||
// EnvICMPMaxEntries caps the ICMP conntrack table size.
|
||||
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
logger *nblog.Logger
|
||||
@@ -61,7 +52,6 @@ type ICMPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -75,7 +65,7 @@ type ICMPInfo struct {
|
||||
|
||||
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||
func (info ICMPInfo) String() string {
|
||||
if info.isErrorMessage() && info.PayloadLen >= minICMPPayloadIPv4 {
|
||||
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||
}
|
||||
@@ -84,72 +74,42 @@ func (info ICMPInfo) String() string {
|
||||
return info.TypeCode.String()
|
||||
}
|
||||
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info.
|
||||
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
|
||||
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
|
||||
// kept as a literal.
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||
func (info ICMPInfo) isErrorMessage() bool {
|
||||
typ := info.TypeCode.Type()
|
||||
// ICMPv4 error types
|
||||
if typ == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
typ == layers.ICMPv4TypeRedirect ||
|
||||
typ == layers.ICMPv4TypeTimeExceeded ||
|
||||
typ == layers.ICMPv4TypeParameterProblem {
|
||||
return true
|
||||
}
|
||||
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
|
||||
if typ == layers.ICMPv6TypeDestinationUnreachable ||
|
||||
typ == layers.ICMPv6TypePacketTooBig ||
|
||||
typ == layers.ICMPv6TypeParameterProblem {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return typ == 3 || // Destination Unreachable
|
||||
typ == 5 || // Redirect
|
||||
typ == 11 || // Time Exceeded
|
||||
typ == 12 // Parameter Problem
|
||||
}
|
||||
|
||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||
func (info ICMPInfo) parseOriginalPacket() string {
|
||||
if info.PayloadLen == 0 {
|
||||
if info.PayloadLen < MaxICMPPayloadLength {
|
||||
return ""
|
||||
}
|
||||
|
||||
version := (info.PayloadData[0] >> 4) & 0xF
|
||||
|
||||
var protocol uint8
|
||||
var srcIP, dstIP net.IP
|
||||
var transportData []byte
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
if info.PayloadLen < minICMPPayloadIPv4 {
|
||||
return ""
|
||||
}
|
||||
protocol = info.PayloadData[9]
|
||||
srcIP = net.IP(info.PayloadData[12:16])
|
||||
dstIP = net.IP(info.PayloadData[16:20])
|
||||
transportData = info.PayloadData[20:]
|
||||
case 6:
|
||||
if info.PayloadLen < minICMPPayloadIPv6 {
|
||||
return ""
|
||||
}
|
||||
// Next Header field in IPv6 header
|
||||
protocol = info.PayloadData[6]
|
||||
srcIP = net.IP(info.PayloadData[8:24])
|
||||
dstIP = net.IP(info.PayloadData[24:40])
|
||||
transportData = info.PayloadData[40:]
|
||||
default:
|
||||
// TODO: handle IPv6
|
||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
protocol := info.PayloadData[9]
|
||||
srcIP := net.IP(info.PayloadData[12:16])
|
||||
dstIP := net.IP(info.PayloadData[16:20])
|
||||
|
||||
transportData := info.PayloadData[20:]
|
||||
|
||||
switch nftypes.Protocol(protocol) {
|
||||
case nftypes.TCP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
|
||||
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.UDP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
|
||||
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.ICMP:
|
||||
icmpType := transportData[0]
|
||||
@@ -175,7 +135,6 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -262,9 +221,7 @@ func (t *ICMPTracker) track(
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
@@ -283,22 +240,16 @@ func (t *ICMPTracker) track(
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
|
||||
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -335,34 +286,6 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *ICMPTracker) evictOneLocked() {
|
||||
var candKey ICMPConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
@@ -371,22 +294,13 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
|
||||
if ip.Is6() {
|
||||
return nftypes.ICMPv6
|
||||
}
|
||||
return nftypes.ICMP
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
@@ -402,7 +316,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: icmpProtocolForAddr(conn.SourceIP),
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
@@ -420,7 +334,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
|
||||
Type: nftypes.TypeStart,
|
||||
RuleID: ruleID,
|
||||
Direction: direction,
|
||||
Protocol: icmpProtocolForAddr(srcIP),
|
||||
Protocol: nftypes.ICMP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
ICMPType: typ,
|
||||
|
||||
@@ -5,42 +5,6 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestICMPConnKey_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key ICMPConnKey
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "IPv4",
|
||||
key: ICMPConnKey{
|
||||
SrcIP: netip.MustParseAddr("192.168.1.1"),
|
||||
DstIP: netip.MustParseAddr("10.0.0.1"),
|
||||
ID: 1234,
|
||||
},
|
||||
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
|
||||
},
|
||||
{
|
||||
name: "IPv6",
|
||||
key: ICMPConnKey{
|
||||
SrcIP: netip.MustParseAddr("2001:db8::1"),
|
||||
DstIP: netip.MustParseAddr("2001:db8::2"),
|
||||
ID: 5678,
|
||||
},
|
||||
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.key.String()
|
||||
if got != tc.expect {
|
||||
t.Errorf("got %q, want %q", got, tc.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
|
||||
@@ -38,27 +38,6 @@ const (
|
||||
TCPHandshakeTimeout = 60 * time.Second
|
||||
// TCPCleanupInterval is how often we check for stale connections
|
||||
TCPCleanupInterval = 5 * time.Minute
|
||||
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
|
||||
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
|
||||
FinWaitTimeout = 60 * time.Second
|
||||
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
|
||||
// holding CloseWait longer than this should bump the env var.
|
||||
CloseWaitTimeout = 60 * time.Second
|
||||
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
|
||||
LastAckTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// Env vars to override per-state teardown timeouts. Values parsed by
|
||||
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
|
||||
// defaults above with a warning.
|
||||
const (
|
||||
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
|
||||
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
|
||||
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
|
||||
|
||||
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
|
||||
// (tombstones first) are evicted when the cap is reached.
|
||||
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
@@ -154,18 +133,14 @@ func (t *TCPConnTrack) SetTombstone() {
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
finWaitTimeout time.Duration
|
||||
closeWaitTimeout time.Duration
|
||||
lastAckTimeout time.Duration
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
@@ -180,17 +155,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
|
||||
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
|
||||
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
|
||||
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
|
||||
flowLogger: flowLogger,
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
@@ -238,12 +209,6 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
|
||||
// create spurious conntrack entries. Not mandated by RFC 9293 but a
|
||||
// common hardening (Linux netfilter/nftables rejects these too).
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
@@ -260,65 +225,20 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
|
||||
// iteration order is randomized), preferring a tombstone. If no tombstone
|
||||
// found in the sample, evicts the oldest among the sampled entries. O(1)
|
||||
// worst case — cheap enough to run on every insert at cap during abuse.
|
||||
func (t *TCPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
if c.IsTombstone() {
|
||||
delete(t.connections, k)
|
||||
return
|
||||
}
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
// TypeEnd is already emitted at the state transition to
|
||||
// TimeWait and when a connection is tombstoned. Only emit
|
||||
// here when we're reaping a still-active flow.
|
||||
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
@@ -336,19 +256,12 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return false
|
||||
}
|
||||
|
||||
// Reject illegal flag combinations regardless of state. These never belong
|
||||
// to a legitimate flow and must not advance or tear down state.
|
||||
if !isValidFlagCombination(flags) {
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -357,208 +270,116 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags.
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
// Malformed flag combinations must not refresh lastSeen or drive state,
|
||||
// otherwise spoofed packets keep a dead flow alive past its timeout.
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
|
||||
// cannot apply the RFC 5961 in-window RST check, so we conservatively
|
||||
// reject RSTs that the spec would accept (TIME-WAIT with in-window
|
||||
// SEQ, SynSent from same direction as own SYN, etc.).
|
||||
t.handleRst(key, conn, currentState, packetDir)
|
||||
return
|
||||
}
|
||||
|
||||
newState := nextState(currentState, conn.Direction, packetDir, flags)
|
||||
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
|
||||
return
|
||||
}
|
||||
t.onTransition(key, conn, currentState, newState, packetDir)
|
||||
}
|
||||
|
||||
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
|
||||
// from the SYN direction are ignored; otherwise the flow is tombstoned.
|
||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
|
||||
// TimeWait exists to absorb late segments; don't let a late RST
|
||||
// tombstone the entry and break same-4-tuple reuse.
|
||||
if currentState == TCPStateTimeWait {
|
||||
return
|
||||
}
|
||||
// A RST from the same direction as the SYN cannot be a legitimate
|
||||
// response and must not tear down a half-open connection.
|
||||
if currentState == TCPStateSynSent && packetDir == conn.Direction {
|
||||
return
|
||||
}
|
||||
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
return
|
||||
}
|
||||
conn.SetTombstone()
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
// stateTransition describes one state's transition logic. It receives the
|
||||
// packet's flags plus whether the packet direction matches the connection's
|
||||
// origin direction (same=true means same side as the SYN initiator). Return 0
|
||||
// for no transition.
|
||||
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
|
||||
|
||||
// stateTable maps each state to its transition function. Centralized here so
|
||||
// nextState stays trivial and each rule is easy to read in isolation.
|
||||
var stateTable = map[TCPState]stateTransition{
|
||||
TCPStateNew: transNew,
|
||||
TCPStateSynSent: transSynSent,
|
||||
TCPStateSynReceived: transSynReceived,
|
||||
TCPStateEstablished: transEstablished,
|
||||
TCPStateFinWait1: transFinWait1,
|
||||
TCPStateFinWait2: transFinWait2,
|
||||
TCPStateClosing: transClosing,
|
||||
TCPStateCloseWait: transCloseWait,
|
||||
TCPStateLastAck: transLastAck,
|
||||
}
|
||||
|
||||
// nextState returns the target TCP state for the given current state and
|
||||
// packet, or 0 if the packet does not trigger a transition.
|
||||
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
|
||||
fn, ok := stateTable[currentState]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return fn(flags, connDir, packetDir == connDir)
|
||||
}
|
||||
|
||||
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if connDir == nftypes.Egress {
|
||||
return TCPStateSynSent
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
return TCPStateSynReceived
|
||||
return
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if same {
|
||||
return TCPStateSynReceived // simultaneous open
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin == 0 {
|
||||
return 0
|
||||
}
|
||||
if same {
|
||||
return TCPStateFinWait1
|
||||
}
|
||||
return TCPStateCloseWait
|
||||
}
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
|
||||
// transFinWait1 handles the active-close peer response. A FIN carrying our
|
||||
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
|
||||
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
|
||||
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
|
||||
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if same {
|
||||
return 0
|
||||
}
|
||||
if flags&TCPFin != 0 && flags&TCPAck != 0 {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
switch {
|
||||
case flags&TCPFin != 0:
|
||||
return TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
return TCPStateFinWait2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
}
|
||||
|
||||
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
|
||||
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateFinWait1:
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
|
||||
// transClosing completes a simultaneous close on the peer's ACK.
|
||||
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
|
||||
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && same {
|
||||
return TCPStateLastAck
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
|
||||
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateClosed
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
|
||||
// onTransition handles logging and flow-event emission after a successful
|
||||
// state transition. TimeWait and Closed are terminal for flow accounting.
|
||||
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
|
||||
traceOn := t.logger.Enabled(nblog.LevelTrace)
|
||||
if traceOn {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
}
|
||||
|
||||
switch to {
|
||||
case TCPStateTimeWait:
|
||||
if traceOn {
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
if traceOn {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current
|
||||
// connection state. Caller must have already verified the flag combination is
|
||||
// legal via isValidFlagCombination.
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
@@ -628,24 +449,15 @@ func (t *TCPTracker) cleanup() {
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
|
||||
timeout = t.finWaitTimeout
|
||||
case TCPStateCloseWait:
|
||||
timeout = t.closeWaitTimeout
|
||||
case TCPStateLastAck:
|
||||
timeout = t.lastAckTimeout
|
||||
default:
|
||||
// SynSent / SynReceived / New
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if currentState != TCPStateTimeWait {
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// RST hygiene tests: the tracker currently closes the flow on any RST that
|
||||
// matches the 4-tuple, regardless of direction or state. These tests cover
|
||||
// the minimum checks we want (no SEQ tracking).
|
||||
|
||||
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
|
||||
// cannot be a legitimate response. It must not close the connection.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState(),
|
||||
"RST in same direction as SYN must not close connection")
|
||||
require.False(t, conn.IsTombstone())
|
||||
}
|
||||
|
||||
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Drive to TIME-WAIT via active close.
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
|
||||
|
||||
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
|
||||
// exists to absorb late segments).
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState(),
|
||||
"RST in TIME-WAIT must not transition state")
|
||||
require.False(t, conn.IsTombstone(),
|
||||
"RST in TIME-WAIT must not tombstone the entry")
|
||||
}
|
||||
|
||||
func TestTCPIllegalFlagCombos(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
conn := tracker.connections[key]
|
||||
|
||||
// Illegal combos must be rejected and must not change state.
|
||||
combos := []struct {
|
||||
name string
|
||||
flags uint8
|
||||
}{
|
||||
{"SYN+RST", TCPSyn | TCPRst},
|
||||
{"FIN+RST", TCPFin | TCPRst},
|
||||
{"SYN+FIN", TCPSyn | TCPFin},
|
||||
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
|
||||
}
|
||||
|
||||
for _, c := range combos {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
before := conn.GetState()
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
|
||||
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
|
||||
require.Equal(t, before, conn.GetState(),
|
||||
"illegal flag combo must not change state")
|
||||
require.False(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,235 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// These tests exercise cases where the TCP state machine currently advances
|
||||
// on retransmitted or wrong-direction segments and tears the flow down
|
||||
// prematurely. They are expected to fail until the direction checks are added.
|
||||
|
||||
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer sends FIN -> CloseWait (our app has not yet closed).
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
|
||||
// sent our FIN yet, so state must remain CloseWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "retransmitted peer FIN must still be accepted")
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState(),
|
||||
"retransmitted peer FIN must not advance CloseWait to LastAck")
|
||||
|
||||
// Our app finally closes -> LastAck.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Peer ACK closes.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// We initiate close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Stray retransmit of our own FIN (same direction as originator) must
|
||||
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState(),
|
||||
"own FIN retransmit must not advance FinWait2 to TimeWait")
|
||||
|
||||
// Peer FIN -> TimeWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPLastAckDirectionCheck(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must NOT close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState(),
|
||||
"own ACK retransmit in LastAck must not transition to Closed")
|
||||
|
||||
// Peer's ACK -> Closed.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must not advance.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState(),
|
||||
"own ACK in FinWait1 must not advance to FinWait2")
|
||||
}
|
||||
|
||||
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
|
||||
// Verify cleanup reaps entries in each teardown state at the configured
|
||||
// per-state timeout, not at the single handshake timeout.
|
||||
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
|
||||
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
|
||||
t.Setenv(EnvTCPLastAckTimeout, "30ms")
|
||||
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Drives a connection to the target state, forces its lastSeen well
|
||||
// beyond the configured timeout, runs cleanup, and asserts reaping.
|
||||
cases := []struct {
|
||||
name string
|
||||
// drive takes a fresh tracker and returns the conn key after
|
||||
// transitioning the flow into the intended teardown state.
|
||||
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
|
||||
}{
|
||||
{
|
||||
name: "FinWait1",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FinWait2",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWait",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "LastAck",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Use a unique source port per subtest so nothing aliases.
|
||||
port := uint16(12345)
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
|
||||
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
|
||||
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
port++
|
||||
key, wantState := c.drive(t, tracker, srcIP, port)
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, wantState, conn.GetState())
|
||||
|
||||
// Age the entry past the largest per-state timeout.
|
||||
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
|
||||
tracker.cleanup()
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "%s entry should be reaped", c.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
|
||||
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
|
||||
// teardown states, which some stacks emit during close.
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer FIN -> CloseWait.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
|
||||
// Peer pushes trailing data + FIN|PSH|ACK (legal).
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
|
||||
"FIN|PSH|ACK in CloseWait must be accepted")
|
||||
|
||||
// Bare ACK keepalive from peer in CloseWait must be accepted.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
|
||||
"bare ACK in CloseWait must be accepted")
|
||||
}
|
||||
@@ -17,9 +17,6 @@ const (
|
||||
DefaultUDPTimeout = 30 * time.Second
|
||||
// UDPCleanupInterval is how often we check for stale connections
|
||||
UDPCleanupInterval = 15 * time.Second
|
||||
|
||||
// EnvUDPMaxEntries caps the UDP conntrack table size.
|
||||
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
|
||||
)
|
||||
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
@@ -37,7 +34,6 @@ type UDPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -55,7 +51,6 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -122,18 +117,13 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
@@ -161,34 +151,6 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
return true
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *UDPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
@@ -211,10 +173,8 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,10 +18,9 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
@@ -36,10 +35,8 @@ import (
|
||||
const (
|
||||
layerTypeAll = 255
|
||||
|
||||
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
|
||||
ipv4TCPHeaderMinSize = 40
|
||||
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
|
||||
ipv6TCPHeaderMinSize = 60
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
@@ -126,7 +123,7 @@ type Manager struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRules []firewall.Rule
|
||||
blockRule firewall.Rule
|
||||
|
||||
// Internal 1:1 DNAT
|
||||
dnatEnabled atomic.Bool
|
||||
@@ -141,10 +138,9 @@ type Manager struct {
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
|
||||
mtu uint16
|
||||
mssClampValueIPv4 uint16
|
||||
mssClampValueIPv6 uint16
|
||||
mssClampEnabled bool
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[common.PacketHook]
|
||||
@@ -161,28 +157,11 @@ type decoder struct {
|
||||
icmp4 layers.ICMPv4
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser4 *gopacket.DecodingLayerParser
|
||||
parser6 *gopacket.DecodingLayerParser
|
||||
parser *gopacket.DecodingLayerParser
|
||||
|
||||
dnatOrigPort uint16
|
||||
}
|
||||
|
||||
// decodePacket decodes packet data using the appropriate parser based on IP version.
|
||||
func (d *decoder) decodePacket(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("empty packet")
|
||||
}
|
||||
version := data[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
return d.parser4.DecodeLayers(data, &d.decoded)
|
||||
case 6:
|
||||
return d.parser6.DecodeLayers(data, &d.decoded)
|
||||
default:
|
||||
return fmt.Errorf("unknown IP version %d", version)
|
||||
}
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
@@ -240,17 +219,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
d.parser.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
},
|
||||
@@ -276,12 +249,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
|
||||
if !disableMSSClamping {
|
||||
m.mssClampEnabled = true
|
||||
if mtu > ipv4TCPHeaderMinSize {
|
||||
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
||||
}
|
||||
if mtu > ipv6TCPHeaderMinSize {
|
||||
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
||||
}
|
||||
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||
}
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
@@ -304,25 +272,13 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
|
||||
// arrives via the routing path. v4 and v6 are independent: a v6 install
|
||||
// failure leaves v4 protection in place (and vice versa) so the returned
|
||||
// slice always contains whatever was successfully installed, even on error.
|
||||
// Callers must persist the slice so DisableRouting can clean partial state.
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||
wgPrefix := iface.Address().Network
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
v6Net := iface.Address().IPv6Net
|
||||
if v6Net.IsValid() {
|
||||
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
v4Rule, err := m.addRouteFiltering(
|
||||
rule, err := m.addRouteFiltering(
|
||||
nil,
|
||||
sources,
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
firewall.Network{Prefix: wgPrefix},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
@@ -330,30 +286,12 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
||||
firewall.ActionDrop,
|
||||
)
|
||||
if err != nil {
|
||||
return rules, fmt.Errorf("block wg v4 net: %w", err)
|
||||
}
|
||||
rules = append(rules, v4Rule)
|
||||
|
||||
if v6Net.IsValid() {
|
||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||
v6Rule, err := m.addRouteFiltering(
|
||||
nil,
|
||||
sources,
|
||||
firewall.Network{Prefix: v6Net},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
)
|
||||
if err != nil {
|
||||
return rules, fmt.Errorf("block wg v6 net: %w", err)
|
||||
}
|
||||
rules = append(rules, v6Rule)
|
||||
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
|
||||
return rules, nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) determineRouting() error {
|
||||
@@ -583,7 +521,7 @@ func (m *Manager) addRouteFiltering(
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
|
||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
@@ -674,10 +612,10 @@ func (m *Manager) Flush() error { return nil }
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
clear(m.outgoingRules)
|
||||
clear(m.incomingDenyRules)
|
||||
clear(m.incomingRules)
|
||||
clear(m.routeRulesMap)
|
||||
maps.Clear(m.outgoingRules)
|
||||
maps.Clear(m.incomingDenyRules)
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
m.tcpHookOut.Store(nil)
|
||||
@@ -738,7 +676,11 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
|
||||
destinations := matches[0].destinations
|
||||
destinations = append(destinations, prefixes...)
|
||||
for _, prefix := range prefixes {
|
||||
if prefix.Addr().Is4() {
|
||||
destinations = append(destinations, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||
cmp := a.Addr().Compare(b.Addr())
|
||||
@@ -777,7 +719,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -787,9 +729,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if !srcIP.IsValid() {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
}
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -863,32 +803,12 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var mssClampValue uint16
|
||||
var ipHeaderSize int
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
mssClampValue = m.mssClampValueIPv4
|
||||
ipHeaderSize = int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
return false
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
mssClampValue = m.mssClampValueIPv6
|
||||
ipHeaderSize = 40
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if mssClampValue == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
mssOptionIndex := -1
|
||||
var currentMSS uint16
|
||||
for i, opt := range d.tcp.Options {
|
||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||
if currentMSS > mssClampValue {
|
||||
if currentMSS > m.mssClampValue {
|
||||
mssOptionIndex = i
|
||||
break
|
||||
}
|
||||
@@ -899,17 +819,20 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
|
||||
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||
tcpHeaderStart := ipHeaderSize
|
||||
tcpOptionsStart := tcpHeaderStart + 20
|
||||
|
||||
@@ -924,7 +847,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex
|
||||
}
|
||||
|
||||
mssValueOffset := optOffset + 2
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||
|
||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||
return true
|
||||
@@ -934,32 +857,18 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
||||
tcpLayer := packetData[tcpHeaderStart:]
|
||||
tcpLength := len(packetData) - tcpHeaderStart
|
||||
|
||||
// Zero out existing checksum
|
||||
tcpLayer[16] = 0
|
||||
tcpLayer[17] = 0
|
||||
|
||||
// Build pseudo-header checksum based on IP version
|
||||
var pseudoSum uint32
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
case layers.LayerTypeIPv6:
|
||||
for i := 0; i < 16; i += 2 {
|
||||
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
|
||||
}
|
||||
for i := 0; i < 16; i += 2 {
|
||||
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
|
||||
}
|
||||
pseudoSum += uint32(tcpLength)
|
||||
pseudoSum += uint32(layers.IPProtocolTCP)
|
||||
}
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
|
||||
sum := pseudoSum
|
||||
var sum = pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
@@ -997,9 +906,6 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
|
||||
}
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, tc := icmpv6EchoFields(d)
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1013,9 +919,6 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, tc := icmpv6EchoFields(d)
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
|
||||
}
|
||||
|
||||
d.dnatOrigPort = 0
|
||||
@@ -1048,21 +951,15 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
if d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
} else {
|
||||
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
|
||||
}
|
||||
}
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
@@ -1071,7 +968,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
@@ -1097,10 +994,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1150,10 +1045,8 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled.Load() {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
}
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1170,10 +1063,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1209,48 +1100,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return true
|
||||
}
|
||||
|
||||
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
|
||||
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
|
||||
// both families uniformly. The echo ID is in the first two payload bytes.
|
||||
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
|
||||
if len(d.icmp6.Payload) >= 2 {
|
||||
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
||||
}
|
||||
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
|
||||
switch d.icmp6.TypeCode.Type() {
|
||||
case layers.ICMPv6TypeEchoRequest:
|
||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
|
||||
case layers.ICMPv6TypeEchoReply:
|
||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
|
||||
default:
|
||||
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
|
||||
}
|
||||
return id, tc
|
||||
}
|
||||
|
||||
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
|
||||
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
|
||||
// ICMP rules since management sends a single ICMP rule for both families.
|
||||
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
|
||||
if ruleLayer == packetLayer {
|
||||
return true
|
||||
}
|
||||
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
|
||||
return true
|
||||
}
|
||||
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
|
||||
if p.Addr().Is6() {
|
||||
return layers.LayerTypeIPv6
|
||||
}
|
||||
return layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
@@ -1274,10 +1123,8 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
||||
return nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return nftypes.UDP
|
||||
case layers.LayerTypeICMPv4:
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return nftypes.ICMP
|
||||
case layers.LayerTypeICMPv6:
|
||||
return nftypes.ICMPv6
|
||||
default:
|
||||
return nftypes.ProtocolUnknown
|
||||
}
|
||||
@@ -1298,10 +1145,8 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
// It returns true, false if the packet is valid and not a fragment.
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
}
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
return false, false
|
||||
}
|
||||
|
||||
@@ -1313,21 +1158,10 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
}
|
||||
|
||||
// Fragments are also valid
|
||||
if l == 1 {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
|
||||
// only decoded the IPv6 layer, the transport is in a fragment.
|
||||
// TODO: handle non-Fragment extension headers (HopByHop, Routing,
|
||||
// DestOpts) by walking the chain. gopacket's parser does not
|
||||
// support them as DecodingLayers; today we drop such packets.
|
||||
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
|
||||
return true, true
|
||||
}
|
||||
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
ip4 := d.ip4
|
||||
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1365,35 +1199,21 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
||||
size,
|
||||
)
|
||||
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, _ := icmpv6EchoFields(d)
|
||||
return m.icmpTracker.IsValidInbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
id,
|
||||
d.icmp6.TypeCode.Type(),
|
||||
size,
|
||||
)
|
||||
// TODO: ICMPv6
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
|
||||
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeICMPv4:
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
case layers.LayerTypeICMPv6:
|
||||
icmpType := d.icmp6.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv6TypePacketTooBig ||
|
||||
icmpType == layers.ICMPv6TypeTimeExceeded ||
|
||||
icmpType == layers.ICMPv6TypeParameterProblem
|
||||
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
|
||||
@@ -1450,7 +1270,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
if payloadLayer != rule.protoLayer {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1485,7 +1305,8 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
|
||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1546,14 +1367,13 @@ func (m *Manager) EnableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
rules, err := m.blockInvalidRouted(m.wgIface)
|
||||
// Persist whatever was installed even on partial failure, so DisableRouting
|
||||
// can clean it up later.
|
||||
m.blockRules = rules
|
||||
rule, err := m.blockInvalidRouted(m.wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("block invalid routed: %w", err)
|
||||
}
|
||||
|
||||
m.blockRule = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1569,16 +1389,9 @@ func (m *Manager) DisableRouting() error {
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, rule := range m.blockRules {
|
||||
if err := m.deleteRouteRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete block rule: %w", err))
|
||||
}
|
||||
}
|
||||
m.blockRules = nil
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
fwder.Stop()
|
||||
@@ -1586,7 +1399,14 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
if m.blockRule != nil {
|
||||
if err := m.deleteRouteRule(m.blockRule); err != nil {
|
||||
return fmt.Errorf("delete block rule: %w", err)
|
||||
}
|
||||
m.blockRule = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||
@@ -1640,8 +1460,7 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
addr := m.wgIface.Address()
|
||||
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -1023,8 +1023,7 @@ func BenchmarkMSSClamping(b *testing.B) {
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
@@ -1089,8 +1088,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
|
||||
manager.mssClampEnabled = sc.enabled
|
||||
if sc.enabled {
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
manager.mssClampValue = 1240
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
@@ -1143,8 +1141,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
|
||||
@@ -539,236 +539,53 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
localIPv6 := netip.MustParseAddr("fd00::100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
wgNetV6 := netip.MustParsePrefix("fd00::/64")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
IPv6: localIPv6,
|
||||
IPv6Net: wgNetV6,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
srcIP string
|
||||
dstIP string
|
||||
proto fw.Protocol
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
ruleIP string
|
||||
ruleProto fw.Protocol
|
||||
ruleDstPort *fw.Port
|
||||
ruleAction fw.Action
|
||||
shouldBeBlocked bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6: allow TCP from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow UDP from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow ICMPv6 from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: block TCP without rule",
|
||||
srcIP: "fd00::2",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6: drop rule",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 22,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{22}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow all protocols",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 9999,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolALL,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "0.0.0.0",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
|
||||
src := net.ParseIP(srcIP)
|
||||
dst := net.ParseIP(dstIP)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
// Detect address family
|
||||
isV6 := src.To4() == nil
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
}
|
||||
|
||||
var err error
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ipLayer.Protocol = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
err = tcp.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(t, err)
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
|
||||
|
||||
if isV6 {
|
||||
ip6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
case fw.ProtocolUDP:
|
||||
ipLayer.Protocol = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
err = udp.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(t, err)
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
|
||||
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ip6.NextHeader = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||
_ = tcp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
|
||||
case fw.ProtocolUDP:
|
||||
ip6.NextHeader = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||
_ = udp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
|
||||
case fw.ProtocolICMP:
|
||||
ip6.NextHeader = layers.IPProtocolICMPv6
|
||||
icmp := &layers.ICMPv6{
|
||||
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
|
||||
}
|
||||
_ = icmp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6)
|
||||
}
|
||||
} else {
|
||||
ip4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
case fw.ProtocolICMP:
|
||||
ipLayer.Protocol = layers.IPProtocolICMPv4
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||
}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
|
||||
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ip4.Protocol = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||
_ = tcp.SetNetworkLayerForChecksum(ip4)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
|
||||
case fw.ProtocolUDP:
|
||||
ip4.Protocol = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||
_ = udp.SetNetworkLayerForChecksum(ip4)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
|
||||
case fw.ProtocolICMP:
|
||||
ip4.Protocol = layers.IPProtocolICMPv4
|
||||
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4)
|
||||
}
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer)
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -1681,103 +1498,3 @@ func TestRouteACLSet(t *testing.T) {
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
|
||||
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
|
||||
// but the ACL decision itself is correct.
|
||||
func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||
|
||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||
_, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: v6Dst},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
proto gopacket.LayerType
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
allowed bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6 TCP to allowed dest",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 TCP wrong port",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 UDP not matched by TCP rule",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeICMPv6,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 to denied subnet",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 source outside allowed range",
|
||||
srcIP: netip.MustParseAddr("fe80::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,21 +189,21 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||
})
|
||||
|
||||
// Call blockInvalidRouted directly multiple times
|
||||
rules1, err := manager.blockInvalidRouted(ifaceMock)
|
||||
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules1)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
rules2, err := manager.blockInvalidRouted(ifaceMock)
|
||||
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules2)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
rules3, err := manager.blockInvalidRouted(ifaceMock)
|
||||
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules3)
|
||||
require.NotNil(t, rule3)
|
||||
|
||||
// All calls should return the same v4 block rule (idempotent install).
|
||||
assert.Equal(t, rules1[0].ID(), rules2[0].ID(), "Second call should return same v4 rule")
|
||||
assert.Equal(t, rules2[0].ID(), rules3[0].ID(), "Third call should return same v4 rule")
|
||||
// All should return the same rule
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||
|
||||
// Should have exactly 1 route rule
|
||||
manager.mutex.RLock()
|
||||
|
||||
@@ -535,16 +535,11 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
d.parser.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
@@ -643,16 +638,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
d.parser.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
@@ -1058,8 +1048,8 @@ func TestMSSClamping(t *testing.T) {
|
||||
}()
|
||||
|
||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||
require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
|
||||
require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
|
||||
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
@@ -1077,7 +1067,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
})
|
||||
|
||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||
@@ -1101,7 +1091,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
})
|
||||
|
||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||
@@ -1273,18 +1263,13 @@ func TestShouldForward(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err = d.decodePacket(buf.Bytes())
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
@@ -1344,44 +1329,6 @@ func TestShouldForward(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Add IPv6 to the interface and test dual-stack cases
|
||||
wgIPv6 := netip.MustParseAddr("fd00::1")
|
||||
otherIPv6 := netip.MustParseAddr("fd00::2")
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgIP,
|
||||
Network: netip.PrefixFrom(wgIP, 24),
|
||||
IPv6: wgIPv6,
|
||||
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
|
||||
}
|
||||
}
|
||||
|
||||
// Re-create manager to pick up the new address with IPv6
|
||||
require.NoError(t, manager.Close(nil))
|
||||
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
v6Cases := []struct {
|
||||
name string
|
||||
dstIP netip.Addr
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
|
||||
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
|
||||
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
|
||||
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
|
||||
}
|
||||
for _, tt := range v6Cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager.localForwarding = true
|
||||
manager.netstack = false
|
||||
decoder := createTCPDecoder(8080)
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
@@ -55,23 +54,16 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
raw := pkt.NetworkHeader().View().AsSlice()
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
var address tcpip.Address
|
||||
if raw[0]>>4 == 6 {
|
||||
address = header.IPv6(raw).DestinationAddress()
|
||||
} else {
|
||||
address = header.IPv4(raw).DestinationAddress()
|
||||
}
|
||||
|
||||
pktBytes := data.AsSlice()
|
||||
|
||||
address := netHeader.DestinationAddress()
|
||||
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
|
||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
@@ -122,7 +114,5 @@ type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
|
||||
" → " +
|
||||
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
|
||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
@@ -37,31 +36,25 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
ipv6 tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
hasRawICMPv6Access bool
|
||||
pingSemaphore chan struct{}
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
HandleLocal: false,
|
||||
})
|
||||
@@ -80,7 +73,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
},
|
||||
}
|
||||
@@ -89,19 +82,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||
v6Addr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFrom16(v6.As16()),
|
||||
PrefixLen: iface.Address().IPv6Net.Bits(),
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
@@ -110,14 +90,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||
}
|
||||
|
||||
defaultSubnetV6, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom16([16]byte{}),
|
||||
tcpip.MaskFromBytes(make([]byte, 16)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
@@ -126,8 +98,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
{Destination: defaultSubnetV6, NIC: nicID},
|
||||
{
|
||||
Destination: defaultSubnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -140,8 +114,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||
ipv6: addrFromNetipAddr(iface.Address().IPv6),
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
}
|
||||
|
||||
@@ -158,10 +131,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
|
||||
// ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
|
||||
// network layer. This avoids duplicate echo replies (v4) and the v6
|
||||
// auto-reply bug where gVisor responds at the network layer before
|
||||
// our transport handler fires.
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
@@ -180,30 +150,8 @@ func (f *Forwarder) SetCapture(pc PacketCapture) {
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) == 0 {
|
||||
return fmt.Errorf("empty packet")
|
||||
}
|
||||
|
||||
var protoNum tcpip.NetworkProtocolNumber
|
||||
switch payload[0] >> 4 {
|
||||
case 4:
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
|
||||
}
|
||||
if f.handleICMPDirect(payload) {
|
||||
return nil
|
||||
}
|
||||
protoNum = ipv4.ProtocolNumber
|
||||
case 6:
|
||||
if len(payload) < header.IPv6MinimumSize {
|
||||
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
|
||||
}
|
||||
if f.handleICMPDirect(payload) {
|
||||
return nil
|
||||
}
|
||||
protoNum = ipv6.ProtocolNumber
|
||||
default:
|
||||
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
@@ -212,160 +160,11 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
defer pkt.DecRef()
|
||||
|
||||
if f.endpoint.dispatcher != nil {
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
|
||||
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
|
||||
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
|
||||
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
|
||||
// and auto-replies to all v6 echo requests in promiscuous mode.
|
||||
//
|
||||
// Unlike gVisor's network layer, this does not validate ICMP checksums or
|
||||
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
|
||||
func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
ip := header.IPv4(payload)
|
||||
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
ipHdrLen = int(ip.HeaderLength())
|
||||
totalLen := int(ip.TotalLength())
|
||||
if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
icmpLen = totalLen - ipHdrLen
|
||||
if icmpLen < header.ICMPv4MinimumSize {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||
}
|
||||
|
||||
func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
|
||||
if len(payload) < header.IPv6MinimumSize {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
ip := header.IPv6(payload)
|
||||
declaredLen := int(ip.PayloadLength())
|
||||
hdrEnd := header.IPv6MinimumSize + declaredLen
|
||||
if hdrEnd > len(payload) {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
icmpStart, ok := skipIPv6ExtensionsToICMPv6(payload, ip.NextHeader(), hdrEnd)
|
||||
if !ok {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
icmpLen = hdrEnd - icmpStart
|
||||
if icmpLen < header.ICMPv6MinimumSize {
|
||||
return 0, 0, src, dst, false
|
||||
}
|
||||
return icmpStart, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||
}
|
||||
|
||||
// skipIPv6ExtensionsToICMPv6 walks the IPv6 extension-header chain starting
|
||||
// after the fixed header. It advances past Hop-by-Hop, Routing, and
|
||||
// Destination Options headers (which share the NextHeader+ExtLen+6+ExtLen*8
|
||||
// layout) and returns the offset of the ICMPv6 payload. Fragment, ESP, AH,
|
||||
// and unknown identifiers are reported as not handleable so the caller can
|
||||
// defer to gVisor.
|
||||
func skipIPv6ExtensionsToICMPv6(payload []byte, next uint8, hdrEnd int) (int, bool) {
|
||||
off := header.IPv6MinimumSize
|
||||
for {
|
||||
if next == uint8(header.ICMPv6ProtocolNumber) {
|
||||
return off, true
|
||||
}
|
||||
if !isWalkableIPv6ExtHdr(next) {
|
||||
return 0, false
|
||||
}
|
||||
newOff, newNext, ok := advanceIPv6ExtHdr(payload, off, hdrEnd)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
off = newOff
|
||||
next = newNext
|
||||
}
|
||||
}
|
||||
|
||||
func isWalkableIPv6ExtHdr(id uint8) bool {
|
||||
switch id {
|
||||
case uint8(header.IPv6HopByHopOptionsExtHdrIdentifier),
|
||||
uint8(header.IPv6RoutingExtHdrIdentifier),
|
||||
uint8(header.IPv6DestinationOptionsExtHdrIdentifier):
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func advanceIPv6ExtHdr(payload []byte, off, hdrEnd int) (int, uint8, bool) {
|
||||
if off+8 > hdrEnd {
|
||||
return 0, 0, false
|
||||
}
|
||||
extLen := (int(payload[off+1]) + 1) * 8
|
||||
if off+extLen > hdrEnd {
|
||||
return 0, 0, false
|
||||
}
|
||||
return off + extLen, payload[off], true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
|
||||
if len(payload) == 0 {
|
||||
return false
|
||||
}
|
||||
var (
|
||||
ipHdrLen int
|
||||
icmpLen int
|
||||
srcAddr tcpip.Address
|
||||
dstAddr tcpip.Address
|
||||
ok bool
|
||||
)
|
||||
version := payload[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
|
||||
case 6:
|
||||
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Let gVisor handle ICMP destined for our own addresses natively.
|
||||
// Its network-layer auto-reply is correct and efficient for local traffic.
|
||||
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
id := stack.TransportEndpointID{
|
||||
LocalAddress: dstAddr,
|
||||
RemoteAddress: srcAddr,
|
||||
}
|
||||
|
||||
// Trim the buffer to the IP-declared length so gVisor doesn't see padding.
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if version == 6 {
|
||||
return f.handleICMPv6(id, pkt)
|
||||
}
|
||||
return f.handleICMP(id, pkt)
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
@@ -378,14 +177,11 @@ func (f *Forwarder) Stop() {
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
if f.netstack && f.ipv6.Equal(addr) {
|
||||
return netip.IPv6Loopback()
|
||||
}
|
||||
return addrToNetipAddr(addr)
|
||||
return addr.AsSlice()
|
||||
}
|
||||
|
||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||
@@ -419,50 +215,23 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
}
|
||||
}
|
||||
|
||||
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
|
||||
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
|
||||
if !addr.IsValid() {
|
||||
return tcpip.Address{}
|
||||
}
|
||||
if addr.Is4() {
|
||||
return tcpip.AddrFrom4(addr.As4())
|
||||
}
|
||||
return tcpip.AddrFrom16(addr.As16())
|
||||
}
|
||||
|
||||
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
|
||||
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
|
||||
switch addr.Len() {
|
||||
case 4:
|
||||
return netip.AddrFrom4(addr.As4())
|
||||
case 16:
|
||||
return netip.AddrFrom16(addr.As16())
|
||||
default:
|
||||
return netip.Addr{}
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
|
||||
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
|
||||
}
|
||||
|
||||
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, network, addr)
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
|
||||
return false
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug1("forwarder: raw %s socket access available", network)
|
||||
return true
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
}
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
const echoRequestSize = 8
|
||||
|
||||
func makeIPv6(t *testing.T, src, dst netip.Addr, nextHdr uint8, payload []byte) []byte {
|
||||
t.Helper()
|
||||
buf := make([]byte, header.IPv6MinimumSize+len(payload))
|
||||
ip := header.IPv6(buf)
|
||||
ip.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(len(payload)),
|
||||
TransportProtocol: 0, // overwritten below to allow any value
|
||||
HopLimit: 64,
|
||||
SrcAddr: tcpipAddrFromNetip(src),
|
||||
DstAddr: tcpipAddrFromNetip(dst),
|
||||
})
|
||||
buf[6] = nextHdr
|
||||
copy(buf[header.IPv6MinimumSize:], payload)
|
||||
return buf
|
||||
}
|
||||
|
||||
func tcpipAddrFromNetip(a netip.Addr) tcpip.Address {
|
||||
b := a.As16()
|
||||
return tcpip.AddrFrom16(b)
|
||||
}
|
||||
|
||||
func echoRequest() []byte {
|
||||
icmp := make([]byte, echoRequestSize)
|
||||
icmp[0] = uint8(header.ICMPv6EchoRequest)
|
||||
return icmp
|
||||
}
|
||||
|
||||
// extHdr builds a generic IPv6 extension header (HBH/Routing/DestOpts) of the
|
||||
// given total octet length (must be multiple of 8, >= 8) with the given next
|
||||
// header.
|
||||
func extHdr(t *testing.T, next uint8, totalLen int) []byte {
|
||||
t.Helper()
|
||||
require.GreaterOrEqual(t, totalLen, 8)
|
||||
require.Equal(t, 0, totalLen%8)
|
||||
buf := make([]byte, totalLen)
|
||||
buf[0] = next
|
||||
buf[1] = uint8(totalLen/8 - 1)
|
||||
return buf
|
||||
}
|
||||
|
||||
func TestParseICMPv6_NoExtensions(t *testing.T) {
|
||||
src := netip.MustParseAddr("fd00::1")
|
||||
dst := netip.MustParseAddr("fd00::2")
|
||||
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), echoRequest())
|
||||
|
||||
off, icmpLen, _, _, ok := parseICMPv6(pkt)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, header.IPv6MinimumSize, off)
|
||||
assert.Equal(t, echoRequestSize, icmpLen)
|
||||
}
|
||||
|
||||
func TestParseICMPv6_SingleExtension(t *testing.T) {
|
||||
src := netip.MustParseAddr("fd00::1")
|
||||
dst := netip.MustParseAddr("fd00::2")
|
||||
hbh := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 8)
|
||||
payload := append([]byte{}, hbh...)
|
||||
payload = append(payload, echoRequest()...)
|
||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
|
||||
|
||||
off, icmpLen, _, _, ok := parseICMPv6(pkt)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, header.IPv6MinimumSize+8, off)
|
||||
assert.Equal(t, echoRequestSize, icmpLen)
|
||||
}
|
||||
|
||||
func TestParseICMPv6_ChainedExtensions(t *testing.T) {
|
||||
src := netip.MustParseAddr("fd00::1")
|
||||
dst := netip.MustParseAddr("fd00::2")
|
||||
dest := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 16)
|
||||
rt := extHdr(t, uint8(header.IPv6DestinationOptionsExtHdrIdentifier), 8)
|
||||
hbh := extHdr(t, uint8(header.IPv6RoutingExtHdrIdentifier), 8)
|
||||
payload := append(append(append([]byte{}, hbh...), rt...), dest...)
|
||||
payload = append(payload, echoRequest()...)
|
||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
|
||||
|
||||
off, icmpLen, _, _, ok := parseICMPv6(pkt)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, header.IPv6MinimumSize+8+8+16, off)
|
||||
assert.Equal(t, echoRequestSize, icmpLen)
|
||||
}
|
||||
|
||||
func TestParseICMPv6_FragmentDefersToGVisor(t *testing.T) {
|
||||
src := netip.MustParseAddr("fd00::1")
|
||||
dst := netip.MustParseAddr("fd00::2")
|
||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6FragmentExtHdrIdentifier), make([]byte, 8))
|
||||
|
||||
_, _, _, _, ok := parseICMPv6(pkt)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseICMPv6_TruncatedExtension(t *testing.T) {
|
||||
src := netip.MustParseAddr("fd00::1")
|
||||
dst := netip.MustParseAddr("fd00::2")
|
||||
// Extension claims 16 bytes but only 8 remain after the IP header.
|
||||
hbh := []byte{uint8(header.ICMPv6ProtocolNumber), 1, 0, 0, 0, 0, 0, 0}
|
||||
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), hbh)
|
||||
|
||||
_, _, _, _, ok := parseICMPv6(pkt)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseICMPv6_TruncatedICMPPayload(t *testing.T) {
|
||||
src := netip.MustParseAddr("fd00::1")
|
||||
dst := netip.MustParseAddr("fd00::2")
|
||||
// PayloadLength claims 8 bytes of ICMPv6 but the buffer only holds 4.
|
||||
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), make([]byte, 8))
|
||||
pkt = pkt[:header.IPv6MinimumSize+4]
|
||||
|
||||
_, _, _, _, ok := parseICMPv6(pkt)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseICMPv4_RejectsShortIHL(t *testing.T) {
|
||||
pkt := make([]byte, 28)
|
||||
pkt[0] = 0x44 // version 4, IHL 4 (16 bytes - below minimum)
|
||||
pkt[9] = uint8(header.ICMPv4ProtocolNumber)
|
||||
header.IPv4(pkt).SetTotalLength(28)
|
||||
|
||||
_, _, _, _, ok := parseICMPv4(pkt)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseICMPv4_RejectsTotalLenOverBuffer(t *testing.T) {
|
||||
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
|
||||
ip := header.IPv4(pkt)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(len(pkt) + 16),
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
TTL: 64,
|
||||
})
|
||||
|
||||
_, _, _, _, ok := parseICMPv4(pkt)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseICMPv4_RejectsFragment(t *testing.T) {
|
||||
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
|
||||
ip := header.IPv4(pkt)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(len(pkt)),
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
TTL: 64,
|
||||
Flags: header.IPv4FlagMoreFragments,
|
||||
})
|
||||
|
||||
_, _, _, _, ok := parseICMPv4(pkt)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
@@ -36,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
@@ -59,7 +58,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPAccess {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false)
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
} else {
|
||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
@@ -73,23 +72,18 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
network, listenAddr := "ip4:icmp", "0.0.0.0"
|
||||
if v6 {
|
||||
network, listenAddr = "ip6:ipv6-icmp", "::"
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, network, listenAddr)
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP.AsSlice()}
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
@@ -98,19 +92,17 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) {
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second)
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
@@ -121,22 +113,16 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
}
|
||||
}()
|
||||
|
||||
txBytes := f.handleEchoResponse(conn, id, v6)
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
proto := "ICMP"
|
||||
if v6 {
|
||||
proto = "ICMPv6"
|
||||
}
|
||||
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
proto, epID(id), icmpType, icmpCode, rtt)
|
||||
}
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int {
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
@@ -151,19 +137,6 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
||||
return 0
|
||||
}
|
||||
|
||||
if v6 {
|
||||
// Recompute checksum: the raw socket response has a checksum computed
|
||||
// over the real endpoint addresses, but we inject with overlay addresses.
|
||||
icmpHdr := header.ICMPv6(response[:n])
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: id.LocalAddress,
|
||||
Dst: id.RemoteAddress,
|
||||
}))
|
||||
return f.injectICMPv6Reply(id, response[:n])
|
||||
}
|
||||
|
||||
return f.injectICMPReply(id, response[:n])
|
||||
}
|
||||
|
||||
@@ -177,23 +150,19 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
txPackets = 1
|
||||
}
|
||||
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
|
||||
proto := nftypes.ICMP
|
||||
if srcIp.Is6() {
|
||||
proto = nftypes.ICMPv6
|
||||
}
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: proto,
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
@@ -229,179 +198,37 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
}
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// handleICMPv6 handles ICMPv6 packets from the network stack.
|
||||
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
if icmpHdr.Type() == header.ICMPv6EchoRequest {
|
||||
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
|
||||
if !f.hasRawICMPv6Access {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback.
|
||||
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPv6Access {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true)
|
||||
} else {
|
||||
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
|
||||
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
|
||||
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyHdr := header.ICMPv6(replyICMP)
|
||||
replyHdr.SetType(header.ICMPv6EchoReply)
|
||||
replyHdr.SetChecksum(0)
|
||||
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
|
||||
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
|
||||
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: replyHdr,
|
||||
Src: id.LocalAddress,
|
||||
Dst: id.RemoteAddress,
|
||||
}))
|
||||
|
||||
return f.injectICMPv6Reply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
|
||||
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv6MinimumSize)
|
||||
ip := header.IPv6(ipHdr)
|
||||
ip.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(len(icmpPayload)),
|
||||
TransportProtocol: header.ICMPv6ProtocolNumber,
|
||||
HopLimit: 64,
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
const (
|
||||
pingBin = "ping"
|
||||
ping6Bin = "ping6"
|
||||
)
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
|
||||
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
isV6 := target.Is6()
|
||||
timeoutStr := fmt.Sprintf("%d", timeoutSec)
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
bin := pingBin
|
||||
if isV6 {
|
||||
bin = ping6Bin
|
||||
}
|
||||
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "openbsd", "netbsd":
|
||||
bin := pingBin
|
||||
if isV6 {
|
||||
bin = ping6Bin
|
||||
}
|
||||
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -12,9 +16,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
@@ -31,14 +33,12 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
}
|
||||
}()
|
||||
|
||||
dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
}
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,22 +61,64 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
success = true
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
// netrelay.Relay copies bidirectionally with proper half-close propagation
|
||||
// and fully closes both conns before returning.
|
||||
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
|
||||
Logger: f.logger,
|
||||
})
|
||||
|
||||
// Close the netstack endpoint after both conns are drained.
|
||||
ep.Close()
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||
errInToOut error
|
||||
errOutToIn error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
||||
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
@@ -85,22 +127,21 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
}
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -125,9 +125,7 @@ func (f *udpForwarder) cleanup() {
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,9 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -162,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
}
|
||||
}()
|
||||
|
||||
dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
@@ -210,9 +206,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
success = true
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
@@ -271,9 +265,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
txPackets = udpStats.PacketsReceived.Value()
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
}
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
@@ -284,14 +276,15 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
|
||||
@@ -13,6 +13,7 @@ const (
|
||||
ipv4HeaderMinLen = 20
|
||||
ipv4ProtoOffset = 9
|
||||
ipv4FlagsOffset = 6
|
||||
ipv4DstOffset = 16
|
||||
ipProtoUDP = 17
|
||||
ipProtoTCP = 6
|
||||
ipv4FragOffMask = 0x1fff
|
||||
|
||||
@@ -4,32 +4,89 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
)
|
||||
|
||||
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
||||
// atomically so reads are lock-free.
|
||||
type localIPSnapshot struct {
|
||||
ips map[netip.Addr]struct{}
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
}
|
||||
|
||||
type localIPManager struct {
|
||||
snapshot atomic.Pointer[localIPSnapshot]
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
m := &localIPManager{}
|
||||
m.snapshot.Store(&localIPSnapshot{
|
||||
ips: make(map[netip.Addr]struct{}),
|
||||
})
|
||||
return m
|
||||
return &localIPManager{}
|
||||
}
|
||||
|
||||
func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
|
||||
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -47,19 +104,18 @@ func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresse
|
||||
continue
|
||||
}
|
||||
|
||||
parsed, ok := netip.AddrFromSlice(ip)
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
parsed = parsed.Unmap()
|
||||
ips[parsed] = struct{}{}
|
||||
*addresses = append(*addresses, parsed)
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -67,20 +123,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
ips := make(map[netip.Addr]struct{})
|
||||
var addresses []netip.Addr
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
|
||||
// loopback
|
||||
ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
|
||||
ips[netip.IPv6Loopback()] = struct{}{}
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
ip := iface.Address().IP
|
||||
ips[ip] = struct{}{}
|
||||
addresses = append(addresses, ip)
|
||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||
ips[v6] = struct{}{}
|
||||
addresses = append(addresses, v6)
|
||||
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,24 +147,25 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
processInterface(intf, ips, &addresses)
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
}
|
||||
|
||||
m.snapshot.Store(&localIPSnapshot{ips: ips})
|
||||
m.mu.Lock()
|
||||
m.ipv4Bitmap = newIPv4Bitmap
|
||||
m.mu.Unlock()
|
||||
|
||||
log.Debugf("Local IP addresses: %v", addresses)
|
||||
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
s := m.snapshot.Load()
|
||||
|
||||
if ip.Is4() && ip.As4()[0] == 127 {
|
||||
return true
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
_, found := s.ips[ip]
|
||||
return found
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func setupManager(b *testing.B) *localIPManager {
|
||||
b.Helper()
|
||||
m := newLocalIPManager()
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
if err := m.UpdateLocalIPs(mock); err != nil {
|
||||
b.Fatalf("UpdateLocalIPs: %v", err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("100.64.0.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("8.8.8.8")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("fd00::1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("2001:db8::1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_loopback(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("127.0.0.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
@@ -72,45 +72,14 @@ func TestLocalIPManager(t *testing.T) {
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address matches",
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fd00::1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address does not match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fd00::99"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No aliasing between similar IPs",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.0.17"),
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("::1"),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -202,3 +171,90 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MapImplementation is a version using map[string]struct{}
|
||||
type MapImplementation struct {
|
||||
localIPs map[string]struct{}
|
||||
}
|
||||
|
||||
func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces := make([]net.IP, 16)
|
||||
for i := range interfaces {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
|
||||
// Setup map version
|
||||
mapManager := &MapImplementation{
|
||||
localIPs: make(map[string]struct{}),
|
||||
}
|
||||
for _, ip := range interfaces[:8] {
|
||||
mapManager.localIPs[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkWGPosition(b *testing.B) {
|
||||
wgIP := net.ParseIP("10.10.0.1")
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
}
|
||||
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -53,17 +53,16 @@ var levelStrings = map[Level]string{
|
||||
}
|
||||
|
||||
type logMessage struct {
|
||||
level Level
|
||||
argCount uint8
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
level Level
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
@@ -108,13 +107,6 @@ func (l *Logger) SetLevel(level Level) {
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
// Enabled reports whether the given level is currently logged. Callers on the
|
||||
// hot path should guard log sites with this to avoid boxing arguments into
|
||||
// any when the level is off.
|
||||
func (l *Logger) Enabled(level Level) bool {
|
||||
return l.level.Load() >= uint32(level)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
@@ -163,7 +155,7 @@ func (l *Logger) Trace(format string) {
|
||||
func (l *Logger) Error1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -172,16 +164,7 @@ func (l *Logger) Error1(format string, arg1 any) {
|
||||
func (l *Logger) Error2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -190,7 +173,7 @@ func (l *Logger) Warn2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -199,7 +182,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -208,7 +191,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -217,7 +200,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
|
||||
func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -226,59 +209,16 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
|
||||
// to avoid allocating an args slice on the fast path when the arg count is
|
||||
// known (0-3). Args beyond 3 land on the general variadic path; callers on
|
||||
// the hot path should prefer DebugN for known counts.
|
||||
func (l *Logger) Debugf(format string, args ...any) {
|
||||
if l.level.Load() < uint32(LevelDebug) {
|
||||
return
|
||||
}
|
||||
switch len(args) {
|
||||
case 0:
|
||||
l.Debug(format)
|
||||
case 1:
|
||||
l.Debug1(format, args[0])
|
||||
case 2:
|
||||
l.Debug2(format, args[0], args[1])
|
||||
case 3:
|
||||
l.Debug3(format, args[0], args[1], args[2])
|
||||
default:
|
||||
l.sendVariadic(LevelDebug, format, args)
|
||||
}
|
||||
}
|
||||
|
||||
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
|
||||
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
|
||||
// beyond the 8-arg slot limit are dropped so callers don't produce silently
|
||||
// empty log lines via uint8 wraparound in argCount.
|
||||
func (l *Logger) sendVariadic(level Level, format string, args []any) {
|
||||
const maxArgs = 8
|
||||
n := len(args)
|
||||
if n > maxArgs {
|
||||
n = maxArgs
|
||||
}
|
||||
msg := logMessage{level: level, argCount: uint8(n), format: format}
|
||||
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
|
||||
for i := 0; i < n; i++ {
|
||||
*slots[i] = args[i]
|
||||
}
|
||||
select {
|
||||
case l.msgChannel <- msg:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -287,7 +227,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
|
||||
func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -296,7 +236,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -305,7 +245,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -314,7 +254,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -323,7 +263,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -333,7 +273,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -346,8 +286,35 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
*buf = append(*buf, levelStrings[msg.level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Count non-nil arguments for switch
|
||||
argCount := 0
|
||||
if msg.arg1 != nil {
|
||||
argCount++
|
||||
if msg.arg2 != nil {
|
||||
argCount++
|
||||
if msg.arg3 != nil {
|
||||
argCount++
|
||||
if msg.arg4 != nil {
|
||||
argCount++
|
||||
if msg.arg5 != nil {
|
||||
argCount++
|
||||
if msg.arg6 != nil {
|
||||
argCount++
|
||||
if msg.arg7 != nil {
|
||||
argCount++
|
||||
if msg.arg8 != nil {
|
||||
argCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var formatted string
|
||||
switch msg.argCount {
|
||||
switch argCount {
|
||||
case 0:
|
||||
formatted = msg.format
|
||||
case 1:
|
||||
|
||||
@@ -11,9 +11,10 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
var (
|
||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||
)
|
||||
@@ -24,33 +25,10 @@ const (
|
||||
destinationPortOffset = 2
|
||||
|
||||
// IP address offsets in IPv4 header
|
||||
ipv4SrcOffset = 12
|
||||
ipv4DstOffset = 16
|
||||
|
||||
// IP address offsets in IPv6 header
|
||||
ipv6SrcOffset = 8
|
||||
ipv6DstOffset = 24
|
||||
|
||||
// IPv6 fixed header length
|
||||
ipv6HeaderLen = 40
|
||||
sourceIPOffset = 12
|
||||
destinationIPOffset = 16
|
||||
)
|
||||
|
||||
// ipHeaderLen returns the IP header length based on the decoded layer type.
|
||||
func ipHeaderLen(d *decoder) (int, error) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
n := int(d.ip4.IHL) * 4
|
||||
if n < 20 {
|
||||
return 0, errInvalidIPHeaderLength
|
||||
}
|
||||
return n, nil
|
||||
case layers.LayerTypeIPv6:
|
||||
return ipv6HeaderLen, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
@@ -256,22 +234,19 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
_, dstIP := extractPacketIPs(packetData, d)
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
}
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
}
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -281,115 +256,54 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP, _ := extractPacketIPs(packetData, d)
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
}
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
}
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
|
||||
func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
|
||||
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
|
||||
case layers.LayerTypeIPv6:
|
||||
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
|
||||
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
|
||||
}
|
||||
return src, dst
|
||||
}
|
||||
|
||||
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
|
||||
case layers.LayerTypeIPv6:
|
||||
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
|
||||
default:
|
||||
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||
if !newIP.Is4() {
|
||||
return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
|
||||
}
|
||||
|
||||
offset := ipv4DstOffset
|
||||
if isSource {
|
||||
offset = ipv4SrcOffset
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldIP [4]byte
|
||||
copy(oldIP[:], packetData[offset:offset+4])
|
||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||
newIPBytes := newIP.As4()
|
||||
copy(packetData[offset:offset+4], newIPBytes[:])
|
||||
|
||||
// Recalculate IPv4 header checksum
|
||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
// Update transport checksums incrementally
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, hdrLen)
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||
if !newIP.Is6() {
|
||||
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
|
||||
}
|
||||
|
||||
offset := ipv6DstOffset
|
||||
if isSource {
|
||||
offset = ipv6SrcOffset
|
||||
}
|
||||
|
||||
var oldIP [16]byte
|
||||
copy(oldIP[:], packetData[offset:offset+16])
|
||||
newIPBytes := newIP.As16()
|
||||
copy(packetData[offset:offset+16], newIPBytes[:])
|
||||
|
||||
// IPv6 has no header checksum, only update transport checksums
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv6:
|
||||
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
|
||||
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -437,20 +351,6 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
|
||||
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
|
||||
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+4 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := icmpStart + 2
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
@@ -503,14 +403,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
// addPortRedirection adds a port redirection rule.
|
||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
rule := portDNATRule{
|
||||
protocol: protocol,
|
||||
origPort: originalPort,
|
||||
targetPort: translatedPort,
|
||||
origPort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
}
|
||||
|
||||
@@ -522,7 +422,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
@@ -533,16 +433,16 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.addPortRedirection(localAddr, layerType, originalPort, translatedPort)
|
||||
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// removePortRedirection removes a port redirection rule.
|
||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
|
||||
return rule.protocol == protocol && rule.origPort == originalPort && rule.targetPort == translatedPort && rule.targetIP.Compare(targetIP) == 0
|
||||
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
|
||||
})
|
||||
|
||||
if len(m.portDNATRules) == 0 {
|
||||
@@ -553,7 +453,7 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
@@ -564,23 +464,23 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||
}
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
@@ -621,9 +521,7 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
||||
}
|
||||
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
}
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
@@ -634,12 +532,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
||||
|
||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
tcpStart := hdrLen
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
@@ -665,12 +563,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
|
||||
|
||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
udpStart := hdrLen
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return fmt.Errorf("packet too short for UDP header")
|
||||
}
|
||||
|
||||
@@ -342,17 +342,12 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
|
||||
// Parse the packet fresh each time to get a clean decoder
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
err = d.decodePacket(testPacket)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
manager.translateOutboundDNAT(testPacket, d)
|
||||
@@ -376,17 +371,12 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||
b.Run("decoder_extraction", func(b *testing.B) {
|
||||
// Create decoder once for comparison
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
err := d.decodePacket(packet)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
@@ -86,18 +86,13 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err := d.decodePacket(packetData)
|
||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -114,13 +112,10 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
ipLayer, err := p.buildIPLayer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pktLayers := []gopacket.SerializableLayer{ipLayer}
|
||||
ip := p.buildIPLayer()
|
||||
pktLayers := []gopacket.SerializableLayer{ip}
|
||||
|
||||
transportLayer, err := p.buildTransportLayer(ipLayer)
|
||||
transportLayer, err := p.buildTransportLayer(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -134,43 +129,30 @@ func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
return serializePacket(pktLayers)
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
|
||||
if p.SrcIP.Is4() != p.DstIP.Is4() {
|
||||
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
|
||||
}
|
||||
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
|
||||
if p.SrcIP.Is6() {
|
||||
return &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: proto,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}, nil
|
||||
}
|
||||
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
return &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: proto,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
switch p.Protocol {
|
||||
case "tcp":
|
||||
return p.buildTCPLayer(ipLayer)
|
||||
return p.buildTCPLayer(ip)
|
||||
case "udp":
|
||||
return p.buildUDPLayer(ipLayer)
|
||||
return p.buildUDPLayer(ip)
|
||||
case "icmp":
|
||||
return p.buildICMPLayer(ipLayer)
|
||||
return p.buildICMPLayer()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(p.SrcPort),
|
||||
DstPort: layers.TCPPort(p.DstPort),
|
||||
@@ -182,44 +164,24 @@ func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gop
|
||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||
URG: p.TCPState != nil && p.TCPState.URG,
|
||||
}
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
}
|
||||
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{tcp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(p.SrcPort),
|
||||
DstPort: layers.UDPPort(p.DstPort),
|
||||
}
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
}
|
||||
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{udp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
if p.SrcIP.Is6() || p.DstIP.Is6() {
|
||||
icmp := &layers.ICMPv6{
|
||||
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
_ = icmp.SetNetworkLayerForChecksum(nl)
|
||||
}
|
||||
if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply {
|
||||
echo := &layers.ICMPv6Echo{
|
||||
Identifier: 1,
|
||||
SeqNumber: 1,
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp, echo}, nil
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp}, nil
|
||||
}
|
||||
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
@@ -242,17 +204,14 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
|
||||
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||
switch protocol {
|
||||
case fw.ProtocolTCP:
|
||||
return layers.IPProtocolTCP
|
||||
return int(layers.IPProtocolTCP)
|
||||
case fw.ProtocolUDP:
|
||||
return layers.IPProtocolUDP
|
||||
return int(layers.IPProtocolUDP)
|
||||
case fw.ProtocolICMP:
|
||||
if isV6 {
|
||||
return layers.IPProtocolICMPv6
|
||||
}
|
||||
return layers.IPProtocolICMPv4
|
||||
return int(layers.IPProtocolICMPv4)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -275,7 +234,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace := &PacketTrace{Direction: direction}
|
||||
|
||||
// Initial packet decoding
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||
return trace
|
||||
}
|
||||
@@ -297,8 +256,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
trace.Protocol = "ICMP"
|
||||
case layers.LayerTypeICMPv6:
|
||||
trace.Protocol = "ICMPv6"
|
||||
}
|
||||
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||
@@ -362,13 +319,6 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
flags&conntrack.TCPFin != 0)
|
||||
case layers.LayerTypeICMPv4:
|
||||
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||
case layers.LayerTypeICMPv6:
|
||||
var id, seq uint16
|
||||
if len(d.icmp6.Payload) >= 4 {
|
||||
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
||||
seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3])
|
||||
}
|
||||
msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
@@ -445,7 +395,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder.Load() != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true)
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||
@@ -465,7 +415,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||
return trace
|
||||
}
|
||||
@@ -484,7 +434,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||
if portDNATApplied {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||
return true
|
||||
}
|
||||
@@ -494,7 +444,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
|
||||
|
||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||
if nat1to1Applied {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||
return true
|
||||
}
|
||||
@@ -559,7 +509,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP, _ := extractPacketIPs(packetData, d)
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
@@ -589,7 +539,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
|
||||
return false
|
||||
}
|
||||
|
||||
_, dstIP := extractPacketIPs(packetData, d)
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
|
||||
@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||
}
|
||||
addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
|
||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -2,7 +2,7 @@ package device
|
||||
|
||||
// TunAdapter is an interface for create tun device from external service
|
||||
type TunAdapter interface {
|
||||
ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||
UpdateAddr(address string) error
|
||||
ProtectSocket(fd int32) bool
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
searchDomainsToString = ""
|
||||
}
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
return nil, err
|
||||
|
||||
@@ -131,32 +131,23 @@ func (t *TunDevice) Device() *device.Device {
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("add v4 address: %s: %w", string(out), err)
|
||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign a dummy link-local so macOS enables IPv6 on the tun device.
|
||||
// When a real overlay v6 is present, use that instead.
|
||||
v6Addr := "fe80::/64"
|
||||
if t.address.HasIPv6() {
|
||||
v6Addr = t.address.IPv6String()
|
||||
}
|
||||
if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil {
|
||||
log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err)
|
||||
t.address.ClearIPv6()
|
||||
// dummy ipv6 so routing works
|
||||
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
}
|
||||
|
||||
if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err)
|
||||
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
|
||||
if out, err := routeCmd.CombinedOutput(); err != nil {
|
||||
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
|
||||
return err
|
||||
}
|
||||
|
||||
if t.address.HasIPv6() {
|
||||
if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil {
|
||||
log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err)
|
||||
t.address.ClearIPv6()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -151,11 +151,8 @@ func (t *TunDevice) MTU() uint16 {
|
||||
return t.mtu
|
||||
}
|
||||
|
||||
// UpdateAddr updates the device address. On iOS the tunnel is managed by the
|
||||
// NetworkExtension, so we only store the new value. The extension picks up the
|
||||
// change on the next tunnel reconfiguration.
|
||||
func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
t.address = addr
|
||||
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *TunKernelDevice) assignAddr() error {
|
||||
return t.link.assignAddr(&t.address)
|
||||
return t.link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||
|
||||
@@ -3,7 +3,6 @@ package device
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
@@ -64,12 +63,8 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("last ip: %w", err)
|
||||
}
|
||||
|
||||
addresses := []netip.Addr{t.address.IP}
|
||||
if t.address.HasIPv6() {
|
||||
addresses = append(addresses, t.address.IPv6)
|
||||
}
|
||||
log.Debugf("netstack using addresses: %v", addresses)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu))
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
tunIface, net, err := t.nsTun.Create()
|
||||
if err != nil {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type TunDevice struct {
|
||||
type USPDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
port int
|
||||
@@ -30,10 +30,10 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
|
||||
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
return &TunDevice{
|
||||
return &USPDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -43,7 +43,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create tun interface")
|
||||
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
|
||||
if err != nil {
|
||||
@@ -75,7 +75,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
if t.device == nil {
|
||||
return nil, fmt.Errorf("device is not ready yet")
|
||||
}
|
||||
@@ -95,12 +95,12 @@ func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
|
||||
func (t *TunDevice) Close() error {
|
||||
func (t *USPDevice) Close() error {
|
||||
if t.configurer != nil {
|
||||
t.configurer.Close()
|
||||
}
|
||||
@@ -115,39 +115,39 @@ func (t *TunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
func (t *TunDevice) MTU() uint16 {
|
||||
func (t *USPDevice) MTU() uint16 {
|
||||
return t.mtu
|
||||
}
|
||||
|
||||
func (t *TunDevice) DeviceName() string {
|
||||
func (t *USPDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
func (t *USPDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
func (t *USPDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
return link.assignAddr(&t.address)
|
||||
return link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
func (t *USPDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
func (t *USPDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -87,21 +87,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
err = nbiface.Set()
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
return nil, fmt.Errorf("set IPv4 interface MTU: %s", err)
|
||||
}
|
||||
|
||||
if t.address.HasIPv6() {
|
||||
nbiface6, err := luid.IPInterface(windows.AF_INET6)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get IPv6 interface for MTU, continuing v4-only: %v", err)
|
||||
t.address.ClearIPv6()
|
||||
} else {
|
||||
nbiface6.NLMTU = uint32(t.mtu)
|
||||
if err := nbiface6.Set(); err != nil {
|
||||
log.Warnf("failed to set IPv6 interface MTU, continuing v4-only: %v", err)
|
||||
t.address.ClearIPv6()
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err)
|
||||
}
|
||||
err = t.assignAddr()
|
||||
if err != nil {
|
||||
@@ -192,21 +178,8 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
|
||||
|
||||
v4Prefix := t.address.Prefix()
|
||||
if t.address.HasIPv6() {
|
||||
v6Prefix := t.address.IPv6Prefix()
|
||||
log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name)
|
||||
if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil {
|
||||
log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err)
|
||||
t.address.ClearIPv6()
|
||||
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("adding address %s to interface: %s", v4Prefix, t.name)
|
||||
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
|
||||
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
|
||||
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
|
||||
8
client/iface/device/kernel_module.go
Normal file
8
client/iface/device/kernel_module.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build (!linux && !freebsd) || android
|
||||
|
||||
package device
|
||||
|
||||
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
return false
|
||||
}
|
||||
18
client/iface/device/kernel_module_freebsd.go
Normal file
18
client/iface/device/kernel_module_freebsd.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package device
|
||||
|
||||
// WireGuardModuleIsLoaded check if kernel support wireguard
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
|
||||
// we are currently do not use it, since it is required to add wireguard kernel support to
|
||||
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
|
||||
// - https://github.com/mdlayher/socket
|
||||
// TODO: implement kernel space
|
||||
return false
|
||||
}
|
||||
|
||||
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
|
||||
func ModuleTunIsLoaded() bool {
|
||||
// Assume tun supported by freebsd kernel by default
|
||||
// TODO: implement check for module loaded in kernel or build-it
|
||||
return true
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package device
|
||||
|
||||
// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available.
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ModuleTunIsLoaded reports whether the tun device is available.
|
||||
func ModuleTunIsLoaded() bool {
|
||||
return true
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -58,32 +57,32 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
link, err := freebsd.LinkByName(l.name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link by name: %w", err)
|
||||
}
|
||||
|
||||
ip := address.IP.String()
|
||||
|
||||
// Convert prefix length to hex netmask
|
||||
prefixLen := address.Network.Bits()
|
||||
if !address.IP.Is4() {
|
||||
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||
}
|
||||
|
||||
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||
|
||||
log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name)
|
||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||
|
||||
if err := link.AssignAddr(address.IP.String(), mask); err != nil {
|
||||
err = link.AssignAddr(ip, mask)
|
||||
if err != nil {
|
||||
return fmt.Errorf("assign addr: %w", err)
|
||||
}
|
||||
|
||||
if address.HasIPv6() {
|
||||
log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name)
|
||||
cmd := exec.Command("ifconfig", l.name, "inet6", address.IPv6String())
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err)
|
||||
address.ClearIPv6()
|
||||
}
|
||||
}
|
||||
|
||||
if err := link.Up(); err != nil {
|
||||
err = link.Up()
|
||||
if err != nil {
|
||||
return fmt.Errorf("up: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -94,7 +92,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(l, 0)
|
||||
if err != nil {
|
||||
@@ -112,16 +110,20 @@ func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
||||
}
|
||||
|
||||
name := l.attrs.Name
|
||||
addrStr := address.String()
|
||||
|
||||
if err := l.addAddr(name, address.Prefix()); err != nil {
|
||||
return err
|
||||
log.Debugf("adding address %s to interface: %s", addrStr, name)
|
||||
|
||||
addr, err := netlink.ParseAddr(addrStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse addr: %w", err)
|
||||
}
|
||||
|
||||
if address.HasIPv6() {
|
||||
if err := l.addAddr(name, address.IPv6Prefix()); err != nil {
|
||||
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err)
|
||||
address.ClearIPv6()
|
||||
}
|
||||
err = netlink.AddrAdd(l, addr)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", name, addrStr)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("add addr: %w", err)
|
||||
}
|
||||
|
||||
// On linux, the link must be brought up
|
||||
@@ -131,22 +133,3 @@ func (l *wgLink) assignAddr(address *wgaddr.Address) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error {
|
||||
log.Debugf("adding address %s to interface: %s", prefix, ifaceName)
|
||||
|
||||
addr := &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: prefix.Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
|
||||
},
|
||||
}
|
||||
|
||||
if err := netlink.AddrAdd(l, addr); os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", ifaceName, prefix)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("add addr %s: %w", prefix, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ type wgProxyFactory interface {
|
||||
|
||||
type WGIFaceOpts struct {
|
||||
IFaceName string
|
||||
Address wgaddr.Address
|
||||
Address string
|
||||
WGPort int
|
||||
WGPrivKey string
|
||||
MTU uint16
|
||||
@@ -141,11 +141,16 @@ func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
}
|
||||
|
||||
// UpdateAddr updates address of the interface
|
||||
func (w *WGIface) UpdateAddr(newAddr wgaddr.Address) error {
|
||||
func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.tun.UpdateAddr(newAddr)
|
||||
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return w.tun.UpdateAddr(addr)
|
||||
}
|
||||
|
||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||
|
||||
@@ -4,17 +4,23 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
@@ -22,7 +28,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
35
client/iface/iface_new_darwin.go
Normal file
35
client/iface/iface_new_darwin.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build !ios
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
} else {
|
||||
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
41
client/iface/iface_new_freebsd.go
Normal file
41
client/iface/iface_new_freebsd.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
@@ -5,15 +5,21 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
|
||||
@@ -4,15 +4,21 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
relayBind := bind.NewRelayBindJS()
|
||||
|
||||
wgIface := &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
||||
}
|
||||
|
||||
@@ -3,40 +3,44 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}, nil
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.WireGuardModuleIsLoaded() {
|
||||
return &WGIface{
|
||||
tun: device.NewKernelDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet),
|
||||
wgProxyFactory: wgproxy.NewKernelFactory(opts.WGPort, opts.MTU),
|
||||
}, nil
|
||||
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
|
||||
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}, nil
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("tun module not available")
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
|
||||
@@ -1,28 +1,33 @@
|
||||
//go:build !linux && !ios && !android && !js
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
} else {
|
||||
tun = device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
}
|
||||
|
||||
return &WGIface{
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}, nil
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
}
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
@@ -49,7 +48,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(addr),
|
||||
Address: addr,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -85,7 +84,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
|
||||
//update WireGuard address
|
||||
addr = "100.64.0.2/8"
|
||||
err = iface.UpdateAddr(wgaddr.MustParseWGAddress(addr))
|
||||
err = iface.UpdateAddr(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -131,7 +130,7 @@ func Test_CreateInterface(t *testing.T) {
|
||||
}
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
Address: wgIP,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -175,7 +174,7 @@ func Test_Close(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -220,7 +219,7 @@ func TestRecreation(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -292,7 +291,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
}
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -348,7 +347,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
Address: wgIP,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -418,7 +417,7 @@ func Test_RemovePeer(t *testing.T) {
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(wgIP),
|
||||
Address: wgIP,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
@@ -483,7 +482,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
optsPeer1 := WGIFaceOpts{
|
||||
IFaceName: peer1ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(peer1wgIP.String()),
|
||||
Address: peer1wgIP.String(),
|
||||
WGPort: peer1wgPort,
|
||||
WGPrivKey: peer1Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
@@ -523,7 +522,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
optsPeer2 := WGIFaceOpts{
|
||||
IFaceName: peer2ifaceName,
|
||||
Address: wgaddr.MustParseWGAddress(peer2wgIP.String()),
|
||||
Address: peer2wgIP.String(),
|
||||
WGPort: peer2wgPort,
|
||||
WGPrivKey: peer2Key.String(),
|
||||
MTU: DefaultMTU,
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||
|
||||
type NetStackTun struct { //nolint:revive
|
||||
addresses []netip.Addr
|
||||
address netip.Addr
|
||||
dnsAddress netip.Addr
|
||||
mtu int
|
||||
listenAddress string
|
||||
@@ -22,9 +22,9 @@ type NetStackTun struct { //nolint:revive
|
||||
tundev tun.Device
|
||||
}
|
||||
|
||||
func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||
return &NetStackTun{
|
||||
addresses: addresses,
|
||||
address: address,
|
||||
dnsAddress: dnsAddress,
|
||||
mtu: mtu,
|
||||
listenAddress: listenAddress,
|
||||
@@ -33,7 +33,7 @@ func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress net
|
||||
|
||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||
t.addresses,
|
||||
[]netip.Addr{t.address},
|
||||
[]netip.Addr{t.dnsAddress},
|
||||
t.mtu)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,18 +3,12 @@ package wgaddr
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
IP netip.Addr
|
||||
Network netip.Prefix
|
||||
|
||||
// IPv6 overlay address, if assigned.
|
||||
IPv6 netip.Addr
|
||||
IPv6Net netip.Prefix
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
@@ -29,60 +23,6 @@ func ParseWGAddress(address string) (Address, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HasIPv6 reports whether a v6 overlay address is assigned.
|
||||
func (addr Address) HasIPv6() bool {
|
||||
return addr.IPv6.IsValid()
|
||||
}
|
||||
|
||||
func (addr Address) String() string {
|
||||
return addr.Prefix().String()
|
||||
}
|
||||
|
||||
// IPv6String returns the v6 address in CIDR notation, or empty string if none.
|
||||
func (addr Address) IPv6String() string {
|
||||
if !addr.HasIPv6() {
|
||||
return ""
|
||||
}
|
||||
return addr.IPv6Prefix().String()
|
||||
}
|
||||
|
||||
// Prefix returns the v4 host address with its network prefix length (e.g. 100.64.0.1/16).
|
||||
func (addr Address) Prefix() netip.Prefix {
|
||||
return netip.PrefixFrom(addr.IP, addr.Network.Bits())
|
||||
}
|
||||
|
||||
// IPv6Prefix returns the v6 host address with its network prefix length, or a zero prefix if none.
|
||||
func (addr Address) IPv6Prefix() netip.Prefix {
|
||||
if !addr.HasIPv6() {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
return netip.PrefixFrom(addr.IPv6, addr.IPv6Net.Bits())
|
||||
}
|
||||
|
||||
// SetIPv6FromCompact decodes a compact prefix (5 or 17 bytes) and sets the IPv6 fields.
|
||||
// Returns an error if the bytes are invalid. A nil or empty input is a no-op.
|
||||
//
|
||||
//nolint:recvcheck
|
||||
func (addr *Address) SetIPv6FromCompact(raw []byte) error {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
prefix, err := netiputil.DecodePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode v6 overlay address: %w", err)
|
||||
}
|
||||
if !prefix.Addr().Is6() {
|
||||
return fmt.Errorf("expected IPv6 address, got %s", prefix.Addr())
|
||||
}
|
||||
addr.IPv6 = prefix.Addr()
|
||||
addr.IPv6Net = prefix.Masked()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearIPv6 removes the IPv6 overlay address, leaving only v4.
|
||||
//
|
||||
//nolint:recvcheck
|
||||
func (addr *Address) ClearIPv6() {
|
||||
addr.IPv6 = netip.Addr{}
|
||||
addr.IPv6Net = netip.Prefix{}
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package wgaddr
|
||||
|
||||
// MustParseWGAddress parses and returns a WG Address, panicking on error.
|
||||
func MustParseWGAddress(address string) Address {
|
||||
a, err := ParseWGAddress(address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -196,25 +196,18 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is derived from the
|
||||
// last two bytes of the peer address (works for both IPv4 and IPv6).
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
if peerAddress == nil {
|
||||
return nil, fmt.Errorf("nil peer address")
|
||||
}
|
||||
if peerAddress.Port < 0 || peerAddress.Port > 65535 {
|
||||
return nil, fmt.Errorf("invalid UDP port: %d", peerAddress.Port)
|
||||
}
|
||||
|
||||
addr, ok := netip.AddrFromSlice(peerAddress.IP)
|
||||
if !ok {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
raw := addr.As16()
|
||||
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
|
||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse new IP: %w", err)
|
||||
}
|
||||
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
|
||||
@@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
|
||||
; or HKCU by legacy installers.
|
||||
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
@@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart entries from every view a previous installer may have used.
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
@@ -105,10 +105,6 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||
// the missing deny. Currently we accumulate errors and continue.
|
||||
var merr *multierror.Error
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
@@ -121,8 +117,9 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
continue
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
}
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
@@ -130,10 +127,6 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
}
|
||||
|
||||
if merr != nil {
|
||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
@@ -223,9 +216,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
ip, err := extractRuleIP(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
@@ -296,13 +289,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -312,7 +305,7 @@ func (d *DefaultManager) addInRules(
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
@@ -322,7 +315,7 @@ func (d *DefaultManager) addOutRules(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -330,9 +323,9 @@ func (d *DefaultManager) addOutRules(
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
ip netip.Addr,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
@@ -351,25 +344,15 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
|
||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
||||
}
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
|
||||
@@ -321,7 +321,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.DisableIPv6,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
|
||||
@@ -14,13 +14,10 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
@@ -116,6 +113,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -125,6 +123,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
@@ -537,20 +536,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
wgAddr, err := wgaddr.ParseWGAddress(peerConfig.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse overlay address %q: %w", peerConfig.Address, err)
|
||||
}
|
||||
|
||||
if !config.DisableIPv6 {
|
||||
if err := wgAddr.SetIPv6FromCompact(peerConfig.GetAddressV6()); err != nil {
|
||||
log.Warn(err)
|
||||
}
|
||||
}
|
||||
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: wgAddr,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
@@ -575,7 +563,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
BlockInbound: config.BlockInbound,
|
||||
DisableIPv6: config.DisableIPv6,
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
|
||||
@@ -650,7 +637,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.DisableIPv6,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
|
||||
@@ -40,10 +40,6 @@ func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
||||
// network stack, not by OS-level interface configuration.
|
||||
}
|
||||
|
||||
func (noopNetworkChangeListener) SetInterfaceIPv6(string) {
|
||||
// No-op: same as SetInterfaceIP, IPv6 overlay is managed by userspace stack.
|
||||
}
|
||||
|
||||
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||
// DNS readiness notifications are not needed in netstack/embed mode
|
||||
// since system DNS is disabled and DNS resolution happens externally.
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -45,11 +44,8 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
|
||||
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
|
||||
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
|
||||
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
|
||||
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
@@ -168,33 +164,22 @@ The config.txt file contains anonymized configuration information of the NetBird
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
|
||||
Firewall Rules (Linux only)
|
||||
The bundle includes the following firewall-related files:
|
||||
The bundle includes two separate firewall rule files:
|
||||
|
||||
iptables.txt:
|
||||
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
|
||||
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
|
||||
- Includes all tables (filter, nat, mangle, raw, security)
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
|
||||
ip6tables.txt:
|
||||
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
|
||||
- Same table coverage and anonymization as iptables.txt
|
||||
- Omitted when ip6tables is not installed or no IPv6 rules are present
|
||||
|
||||
ipset.txt:
|
||||
- Output of 'ipset list' (family-agnostic)
|
||||
- IP addresses are anonymized; set names and types remain unchanged
|
||||
|
||||
nftables.txt:
|
||||
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
|
||||
- Complete nftables ruleset obtained via 'nft -a list ruleset'
|
||||
- Includes rule handle numbers and packet counters
|
||||
- All IP addresses are anonymized; chain/table names remain unchanged
|
||||
|
||||
sysctls.txt:
|
||||
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
|
||||
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
|
||||
- Interface names anonymized when --anonymize is set
|
||||
- All tables, chains, and rules are included
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
|
||||
IP Rules (Linux only)
|
||||
The ip_rules.txt file contains detailed IP routing rule information:
|
||||
@@ -426,10 +411,6 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSysctls(); err != nil {
|
||||
log.Errorf("failed to add sysctls to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
@@ -643,7 +624,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
|
||||
configContent.WriteString(fmt.Sprintf("DisableIPv6: %v\n", g.internalConfig.DisableIPv6))
|
||||
|
||||
if g.internalConfig.DisableNotifications != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
|
||||
@@ -1314,21 +1294,6 @@ func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anon
|
||||
config.Address = anonymizer.AnonymizeIP(addr).String()
|
||||
}
|
||||
|
||||
if len(config.GetAddressV6()) > 0 {
|
||||
v6Prefix, err := netiputil.DecodePrefix(config.GetAddressV6())
|
||||
if err != nil {
|
||||
config.AddressV6 = nil
|
||||
} else {
|
||||
anonV6 := anonymizer.AnonymizeIP(v6Prefix.Addr())
|
||||
b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonV6, v6Prefix.Bits()))
|
||||
if err != nil {
|
||||
config.AddressV6 = nil
|
||||
} else {
|
||||
config.AddressV6 = b
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anonymizeSSHConfig(config.SshConfig)
|
||||
|
||||
config.Dns = anonymizer.AnonymizeString(config.Dns)
|
||||
@@ -1431,20 +1396,8 @@ func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.An
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility
|
||||
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
|
||||
rule.PeerIP = anonymizer.AnonymizeIP(addr).String() //nolint:staticcheck
|
||||
}
|
||||
|
||||
for i, raw := range rule.GetSourcePrefixes() {
|
||||
p, err := netiputil.DecodePrefix(raw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
anonAddr := anonymizer.AnonymizeIP(p.Addr())
|
||||
if b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonAddr, p.Bits())); err == nil {
|
||||
rule.SourcePrefixes[i] = b
|
||||
}
|
||||
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -124,18 +124,15 @@ func getSystemdLogs(serviceName string) (string, error) {
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
|
||||
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
iptablesRules, err := collectIPTablesRules()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||
} else {
|
||||
if g.anonymize {
|
||||
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
|
||||
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
|
||||
log.Warnf("Failed to add ipset output to bundle: %v", err)
|
||||
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,65 +151,44 @@ func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
|
||||
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
|
||||
rules, err := collectIPTablesRules(saveBin, listBin)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect %s rules: %v", listBin, err)
|
||||
return
|
||||
}
|
||||
if g.anonymize {
|
||||
rules = g.anonymizer.AnonymizeString(rules)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
|
||||
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
|
||||
// Returns an error when neither command produced any output (e.g. the binary is missing),
|
||||
// so the caller can skip writing an empty file.
|
||||
func collectIPTablesRules(saveBin, listBin string) (string, error) {
|
||||
// collectIPTablesRules collects rules using both iptables-save and verbose listing
|
||||
func collectIPTablesRules() (string, error) {
|
||||
var builder strings.Builder
|
||||
var collected bool
|
||||
var firstErr error
|
||||
|
||||
saveOutput, err := runCommand(saveBin)
|
||||
switch {
|
||||
case err != nil:
|
||||
firstErr = err
|
||||
log.Warnf("Failed to collect %s output: %v", saveBin, err)
|
||||
case strings.TrimSpace(saveOutput) == "":
|
||||
log.Debugf("%s produced no output, skipping", saveBin)
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
|
||||
saveOutput, err := collectIPTablesSave()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== iptables-save output ===\n")
|
||||
builder.WriteString(saveOutput)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
|
||||
builder.WriteString(listHeader)
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== ipset list output ===\n")
|
||||
builder.WriteString(ipsetOutput)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||
|
||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||
|
||||
for _, table := range tables {
|
||||
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
|
||||
stats, err := getTableStatistics(table)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
|
||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
builder.WriteString(stats)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
if !collected {
|
||||
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
@@ -238,15 +214,34 @@ func collectIPSets() (string, error) {
|
||||
return ipsets, nil
|
||||
}
|
||||
|
||||
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
|
||||
func runCommand(name string, args ...string) (string, error) {
|
||||
cmd := exec.Command(name, args...)
|
||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||
func collectIPTablesSave() (string, error) {
|
||||
cmd := exec.Command("iptables-save")
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
|
||||
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
rules := stdout.String()
|
||||
if strings.TrimSpace(rules) == "" {
|
||||
return "", fmt.Errorf("no iptables rules found")
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// getTableStatistics gets verbose statistics for an entire table using iptables command
|
||||
func getTableStatistics(table string) (string, error) {
|
||||
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
@@ -809,91 +804,3 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
|
||||
return fmt.Sprintf("type-%v", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
log.Info("Collecting sysctls")
|
||||
content := collectSysctls()
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
|
||||
return fmt.Errorf("add sysctls to bundle: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectSysctls reads every sysctl that the netbird client may modify, plus
|
||||
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
|
||||
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
|
||||
func collectSysctls() string {
|
||||
var builder strings.Builder
|
||||
|
||||
writeSysctlGroup(&builder, "forwarding", []string{
|
||||
"net.ipv4.ip_forward",
|
||||
"net.ipv6.conf.all.forwarding",
|
||||
"net.ipv6.conf.default.forwarding",
|
||||
})
|
||||
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
|
||||
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
|
||||
writeSysctlGroup(&builder, "rp_filter", append(
|
||||
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
|
||||
listInterfaceSysctls("ipv4", "rp_filter")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "src_valid_mark", append(
|
||||
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
||||
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "conntrack", []string{
|
||||
"net.netfilter.nf_conntrack_acct",
|
||||
"net.netfilter.nf_conntrack_tcp_loose",
|
||||
})
|
||||
writeSysctlGroup(&builder, "tcp", []string{
|
||||
"net.ipv4.tcp_tw_reuse",
|
||||
})
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
|
||||
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
|
||||
for _, key := range keys {
|
||||
value, err := readSysctl(key)
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
|
||||
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
|
||||
// (callers add those explicitly so they appear first).
|
||||
func listInterfaceSysctls(family, leaf string) []string {
|
||||
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var keys []string
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if name == "all" || name == "default" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func readSysctl(key string) (string, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
value, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(value)), nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user