Compare commits

..

20 Commits

Author SHA1 Message Date
Viktor Liu
9925f64e76 Unify peer and route ACL filtering with multi-source peer rules 2026-06-01 20:25:12 +02:00
Pascal Fischer
9189625487 [management] enrich context in permissions manager (#6286) 2026-05-29 16:36:38 +02:00
Bethuel Mmbaga
e9dbf9db6f [management] Extend combined server initialization (#6156) 2026-05-29 17:35:35 +03:00
Theodor Midtlien
5a9e9e7bc9 [Infrastructure] Pin actions with SHA and improve workflows (#6249)
* Pin actions with SHA, replace unmaintained, add dependabot for actions

* Update FreeBSD to version 15 for tests

* Use shared actions

* Update sign-pipelines version
2026-05-29 15:24:30 +02:00
Viktor Liu
43e041cf9f [client] Apply netroute unspecified-destination workaround on android (#6192) 2026-05-29 15:15:22 +02:00
Viktor Liu
77e5693200 [client] Recognize NetBird DNS forwarder port in capture text format (#6177) 2026-05-29 15:14:32 +02:00
Zoltan Papp
174dc24867 [management] Add SSO session extend flow (management) (#6197)
* add SSO session extend flow (management)

Adds the management-server half of the SSO session-extension feature:

- New ExtendAuthSession gRPC RPC that refreshes a peer's session expiry
  using a fresh JWT, validated through the same pipeline as Login but
  without tearing down the tunnel or redoing the NetworkMap sync.
- Per-peer SessionExpiresAt timestamp on every LoginResponse and
  SyncResponse so connected clients learn the deadline on the existing
  long-lived stream, and admin-side changes (toggling expiration,
  changing the expiration window) reach every peer within seconds.
- SessionExpiresAt(...) helper on Peer that derives the absolute UTC
  deadline from LastLogin + the account-level PeerLoginExpiration
  setting, returning zero when the peer is not SSO-tracked or expiration
  is disabled.

The matching client-side consumer of these fields lands separately.

* encode SessionExpiresAt as 3-state on the wire

Previously the `sessionExpiresAt` field on LoginResponse, SyncResponse
and ExtendAuthSessionResponse was 2-state: a valid timestamp meant
"new deadline", and nil meant "clear". That conflated two distinct
meanings — "no info in this snapshot" vs "expiry is explicitly off /
peer is not SSO-tracked" — so a Sync push that legitimately couldn't
compute the deadline (settings lookup failed) would silently clear the
client's anchor and lose the warning window.

Three states now, encoded on the same field number (no .proto schema
churn — only comments and the server-side encoder change):

  - nil pointer (field absent) → "no info"; client preserves anchor
  - &Timestamp{} (seconds=0, nanos=0) → explicit "disabled / not SSO"
    sentinel; client clears
  - valid timestamp → new absolute UTC deadline

A new encodeSessionExpiresAt helper centralises the zero/non-zero
encoding and is shared by the Sync, Login and ExtendAuthSession
builders. The Sync builder still emits nil when settings are missing.
Login and ExtendAuthSession always carry an authoritative value.

The matching client-side decoder lands on feature/session-extend.

* add UserExtendedPeerSession activity event

ExtendAuthSession previously reused UserLoggedInPeer for its audit
record, which conflated two distinct user actions: a full interactive
SSO login (tunnel re-established, network map resync) versus an
in-place deadline refresh (tunnel untouched). Auditors reading the log
couldn't tell which one happened, and downstream dashboards/alerts on
"login" volume were polluted by routine extends.

Adds a dedicated UserExtendedPeerSession Activity (code 125,
"user.peer.session.extend") and switches ExtendPeerSession over to it.
The peer-extend audit trail is now distinguishable from interactive
logins.

* make ExtendAuthSession JWT-retry backoff cancellable

Skip the retry log and 200ms wait on the final attempt, and replace the
uncancellable time.Sleep with a select on time.After/ctx.Done so an
upstream cancellation aborts the wait instead of running it to
completion.
2026-05-28 19:14:14 +02:00
Riccardo Manfrin
7ea5e37dd4 [client] Improve rosenpass support (#6136)
* Updates rosenpass version

go-rosenpass v0.4.0 → v0.5.42 bump — detailed findings

Change summary
cunicu.li/go-rosenpass  v0.4.0  → v0.5.42   (target)
cilium/ebpf             v0.15.0 → v0.19.0   (transitive)
gopacket/gopacket       v1.1.1  → v1.4.0    (transitive)
wireguard               2023-07 → 2023-12   (transitive)
wireguard/wgctrl        2023-04 → 2024-12   (transitive)

Wire interop

v0.4.0 (in v0.70.5) <-> v0.5.42 OK
v0.5.42 <-> v0.5.42 OK

Quantum resistance: true both ends

---
**Replay error eliminated.**

Before (on v0.4.0):

`ERROR Failed to handle message: failed to load biscuit (ICR1): detected replay`

Recurring every ~50ms for minutes at a time. Gone entirely after both ends upgraded to v0.5.42. Upstream fix in biscuit/replay handling between v0.4.x and v0.5.x series.

* Fixup [::]:port socket trying to send to v4

* Adds more tests on netbird<->rosenpass interactions

* Anticipates rp handler creation before generateConfig

* [client] Moves deterministic key gen into rosenpass

* go mod tidy

* Adds reminder to reason about rosenpass surface area

* Apply code rabbit suggestions
2026-05-28 09:01:18 +02:00
Riccardo Manfrin
9d7ef9b255 [client] Fix statemanager possible deadlock (#6228)
1. Stop() takes m.mu.Lock() and defers m.mu.Unlock()
2. <-m.done under lock
3. periodicStateSave defers close(m.done)
4. periodicStateSave calls PersistState() (line 256) which does m.mu.Lock()

Double Stop() remains idempotent: second cancel() on dead ctx
 (no-op) and reads done already closed (immediate return).
2026-05-28 08:54:15 +02:00
Pascal Fischer
944a258459 [management] extend nmap monitoring (#6271) 2026-05-27 16:56:02 +02:00
Pascal Fischer
1f9a829f2c [management] update log levels (#6266) 2026-05-27 11:43:49 +02:00
Bethuel Mmbaga
14af179556 [management] Refactor management server bootstrap (#6256) 2026-05-26 17:44:28 +03:00
Pascal Fischer
1fbb5e6d5d [management] fix owner role update (#6264) 2026-05-26 16:37:58 +02:00
Viktor Liu
6771e35d57 [client] Release js.FuncOf callbacks in wasm ssh and rdp to prevent leaks (#5982) 2026-05-26 14:32:39 +02:00
Viktor Liu
e89b1e0596 [proxy, client] Bound embed client WireGuard per-Device memory (#5962) 2026-05-26 11:51:53 +02:00
Philip Laine
d542c60e21 Refactor Linux system info to use syscalls (#6230) 2026-05-25 21:00:24 +02:00
Viktor Liu
4983b5cf17 [client] Match DNS wildcard handlers on label boundaries (#6255) 2026-05-25 18:38:48 +02:00
Viktor Liu
b3b0feb3b8 [client] Filter scoped/cloned default routes from BSD network monitor RTM_ADD (#6208) 2026-05-25 18:38:21 +02:00
Maycon Santos
7aebdd69dd [management, client, proxy] add expose NetBird-only services over tunnel peers (#6226)
Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes.

Wire contract
- ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed.
- ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy.
- ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups.
- ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip.
- ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards.
- PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard.

Data model
- Service.Private (bool) + Service.AccessGroups ([]string, JSON- serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns.
- DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure.
- Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0.

Proxy
- /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go.
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack.
- proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction.
- Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC.
- proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level).

Identity forwarding
- ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request.
- CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing).
- AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships.

OpenAPI/dashboard surface
- ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
2026-05-25 17:41:50 +02:00
Viktor Liu
0358be2313 [client] Revert "Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176)" (#6232)
This reverts commit d927ef468a.
2026-05-21 16:27:12 +02:00
264 changed files with 21585 additions and 9889 deletions

45
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
actions:
patterns:
- "*"
ignore:
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
- dependency-name: "git-town/action"
update-types:
- "version-update:semver-minor"
- "version-update:semver-major"
- package-ecosystem: "gomod"
directories:
- "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
aws-sdk:
patterns:
- "github.com/aws/aws-sdk-go-v2/*"
pion:
patterns:
- "github.com/pion/*"
gorm:
patterns:
- "gorm.io/*"
otel:
patterns:
- "go.opentelemetry.io/*"
testcontainers:
patterns:
- "github.com/testcontainers/testcontainers-go/*"
wireguard:
patterns:
- "golang.zx2c4.com/wireguard*"

View File

@@ -2,16 +2,16 @@ name: Check License Dependencies
on:
push:
branches: [ main ]
branches: [main]
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
jobs:
check-internal-dependencies:
@@ -19,7 +19,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check for problematic license dependencies
run: |
@@ -56,55 +59,57 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"

View File

@@ -83,7 +83,7 @@ jobs:
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'
uses: actions/github-script@v7
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
id: verify
with:
pr_number: ${{ steps.extract.outputs.pr_number }}

View File

@@ -8,11 +8,10 @@ jobs:
post:
runs-on: ubuntu-latest
steps:
- uses: roots/discourse-topic-github-release-action@main
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases
discourse-tags: releases

View File

@@ -3,7 +3,7 @@ name: Git Town
on:
pull_request:
branches:
- '**'
- "**"
jobs:
git-town:
@@ -15,7 +15,9 @@ jobs:
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1.2.1
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
with:
skip-single-stacks: true

View File

@@ -16,16 +16,18 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
@@ -44,4 +46,3 @@ jobs:
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)

View File

@@ -15,20 +15,31 @@ jobs:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with:
usesh: true
copyback: false
release: "14.2"
release: "15.0"
envs: "GO_VERSION"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands
# -e - to faile on first error

View File

@@ -18,9 +18,11 @@ jobs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: dorny/paths-filter@v3
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
id: filter
with:
filters: |
@@ -28,7 +30,7 @@ jobs:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -36,10 +38,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache
with:
path: |
@@ -113,14 +115,16 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: ["386", "amd64"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -128,10 +132,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -154,18 +158,20 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -177,7 +183,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache-restore
with:
path: |
@@ -214,7 +220,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:
@@ -231,10 +237,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -246,10 +254,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -277,14 +285,16 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: ["386", "amd64"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -298,7 +308,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -324,14 +334,16 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: ["386", "amd64"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -343,10 +355,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -370,19 +382,21 @@ jobs:
test_management:
name: "Management / Unit"
needs: [ build-cache ]
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
arch: ["amd64"]
store: ["sqlite", "postgres", "mysql"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -390,10 +404,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -410,7 +424,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -427,7 +441,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
@@ -437,13 +451,13 @@ jobs:
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
needs: [build-cache]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
arch: ["amd64"]
store: ["sqlite", "postgres"]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -474,10 +488,12 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -485,10 +501,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -505,7 +521,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -529,13 +545,13 @@ jobs:
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
needs: [build-cache]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
arch: ["amd64"]
store: ["sqlite", "postgres"]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -566,10 +582,12 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -577,10 +595,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -597,7 +615,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -623,20 +641,22 @@ jobs:
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
needs: [build-cache]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
arch: ["amd64"]
store: ["sqlite", "postgres"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -644,10 +664,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}

View File

@@ -18,10 +18,12 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
id: go
with:
go-version-file: "go.mod"
@@ -33,7 +35,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -44,16 +46,15 @@ jobs:
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
- name: Decompressing wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'

View File

@@ -15,9 +15,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: codespell
uses: codespell-project/actions-codespell@v2
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
@@ -38,13 +40,15 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -52,7 +56,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
with:
version: latest
skip-cache: true

View File

@@ -22,7 +22,9 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: run install script
env:

View File

@@ -16,23 +16,25 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v4
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -52,9 +54,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: install gomobile

View File

@@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
const title = context.payload.pull_request.title;

View File

@@ -3,74 +3,92 @@ name: Proto Version Check
on:
pull_request:
paths:
- "**/*.proto"
- "**/*.pb.go"
- "**/generate.sh"
- "proto-tools.env"
- ".github/workflows/proto-version-check.yml"
jobs:
regenerate-and-diff:
name: Regenerate proto and verify no drift
check-proto-versions:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Load pinned proto toolchain versions
run: |
# shellcheck source=/dev/null
. ./proto-tools.env
{
echo "PROTOC_VERSION=${PROTOC_VERSION}"
echo "PROTOC_GEN_GO_VERSION=${PROTOC_GEN_GO_VERSION}"
echo "PROTOC_GEN_GO_GRPC_VERSION=${PROTOC_GEN_GO_GRPC_VERSION}"
} >> "$GITHUB_ENV"
- name: Setup Go
uses: actions/setup-go@v5
- name: Check for proto tool version changes
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
go-version-file: go.mod
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number,
per_page: 100,
});
- name: Setup protoc
uses: arduino/setup-protoc@f4d5893b897028ff5739576ea0409746887fa536 # v3.0.0
with:
version: ${{ env.PROTOC_VERSION }}
repo-token: ${{ secrets.GITHUB_TOKEN }}
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
return;
}
- name: Install protoc plugins
run: |
go install "google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_VERSION}"
go install "google.golang.org/grpc/cmd/protoc-gen-go-grpc@${PROTOC_GEN_GO_GRPC_VERSION}"
echo "$(go env GOPATH)/bin" >> "$GITHUB_PATH"
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
- name: Verify protoc version matches pin
run: |
actual=$(protoc --version | awk '{print $2}')
if [[ "$actual" != "$PROTOC_VERSION" ]]; then
echo "::error::protoc $actual does not match pinned $PROTOC_VERSION"
exit 1
fi
async function getVersionHeader(path, ref) {
try {
const res = await github.rest.repos.getContent({
owner: context.repo.owner,
repo: context.repo.repo,
path,
ref,
});
if (!res.data.content) {
return { ok: false, reason: 'no inline content (file too large)' };
}
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
const lines = content
.split('\n')
.slice(0, 20)
.filter(line => versionPattern.test(line));
return { ok: true, lines };
} catch (e) {
return { ok: false, reason: e.message };
}
}
- name: Regenerate all proto bindings
run: |
set -euo pipefail
for script in \
client/proto/generate.sh \
shared/signal/proto/generate.sh \
shared/management/proto/generate.sh \
flow/proto/generate.sh \
encryption/testprotos/generate.sh; do
echo "::group::$script"
bash "$script"
echo "::endgroup::"
done
const violations = [];
for (const file of modifiedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
violations.push({
file: file.filename,
base: base.lines,
head: head.lines,
});
}
}
- name: Fail if regeneration changed any tracked or untracked file
run: |
if [[ -n "$(git status --porcelain --untracked-files=all)" ]]; then
echo "::error::Generated proto files drift from .proto sources or pinned tool versions."
echo "Run the generate.sh scripts locally with the toolchain in proto-tools.env and commit the result."
git status --short
exit 1
fi
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n` +
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
).join('\n\n');
core.setFailed(
`Proto version strings changed in generated files.\n` +
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
`Regenerate with the matching tool versions.\n\n` +
details
);
return;
}
console.log('No proto version string changes detected');

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.4"
SIGN_PIPE_VER: "v0.1.5"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -24,7 +24,9 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
@@ -51,19 +53,26 @@ jobs:
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with:
usesh: true
copyback: false
release: "15.0"
envs: "GO_VERSION"
prepare: |
# Install required packages
pkg install -y git curl portlint go
pkg install -y git curl portlint
# Install Go for building
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
@@ -93,19 +102,19 @@ jobs:
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: freebsd-port-files
path: |
@@ -124,26 +133,25 @@ jobs:
env:
flags: ""
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -156,18 +164,18 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -191,7 +199,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
@@ -282,28 +290,28 @@ jobs:
} >> "$GITHUB_OUTPUT"
- name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: release
path: dist/
retention-days: 7
- name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 7
- name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 7
- name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -314,27 +322,26 @@ jobs:
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -375,7 +382,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -404,7 +411,7 @@ jobs:
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: release-ui
path: dist/
@@ -418,16 +425,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -441,7 +449,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
@@ -449,7 +457,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: release-ui-darwin
path: dist/
@@ -474,27 +482,26 @@ jobs:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@v4
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@v4
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with:
name: release-ui
path: release-ui
@@ -514,29 +521,27 @@ jobs:
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
- name: Decompress wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
destination: ${{ env.downloadPath }}\mesa3d.7z
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
@@ -547,35 +552,38 @@ jobs:
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: carlosperate/download-file-action@v2
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
destination: ${{ github.workspace }}\envar_plugin.zip
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
shell: pwsh
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
run: |
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
Copy-Item -Destination $nsisPluginDir -Force
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
@@ -592,7 +600,7 @@ jobs:
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
@@ -611,7 +619,7 @@ jobs:
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
@@ -703,7 +711,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: Sign bin and installer
repo: netbirdio/sign-pipelines

View File

@@ -14,9 +14,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'
inputs: '{ "sha": "${{ github.sha }}" }'

View File

@@ -3,7 +3,7 @@ name: sync tag
on:
push:
tags:
- 'v*'
- "v*"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: sync-tag.yml
ref: main
@@ -29,7 +29,7 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: bump-netbird.yml
ref: main
@@ -42,10 +42,10 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -6,10 +6,10 @@ on:
- main
pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
- "infrastructure_files/**"
- ".github/workflows/test-infrastructure-files.yml"
- "management/cmd/**"
- "signal/cmd/**"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres', 'mysql' ]
store: ["sqlite", "postgres", "mysql"]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
@@ -68,15 +68,17 @@ jobs:
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -139,8 +141,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
@@ -254,7 +256,9 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh

View File

@@ -3,9 +3,9 @@ name: update docs
on:
push:
tags:
- 'v*'
- "v*"
paths:
- 'shared/management/http/api/openapi.yml'
- "shared/management/http/api/openapi.yml"
jobs:
trigger_docs_api_update:
@@ -13,10 +13,10 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger API pages generation
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: generate api pages
repo: netbirdio/docs
ref: "refs/heads/main"
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -19,15 +19,17 @@ jobs:
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
with:
version: latest
install-mode: binary
@@ -42,9 +44,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Build Wasm client
@@ -65,4 +69,3 @@ jobs:
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi

View File

@@ -11,7 +11,7 @@ import (
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Fatal(err)
}
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err)

View File

@@ -12,6 +12,7 @@ import (
"sync"
"github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
@@ -84,6 +85,12 @@ type Options struct {
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// BlockLANAccess blocks the embedded peer from reaching the host's
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
// when the embedded client must never act as a stepping stone into
// the host's local network (e.g. the proxy's overlay peer).
BlockLANAccess bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
@@ -94,6 +101,26 @@ type Options struct {
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
// Performance configures the tunnel's buffer pool cap and batch size.
Performance Performance
}
// Performance configures the embedded client's tunnel memory/throughput knobs.
//
// These settings are process-global: any non-nil field also becomes the
// default for Clients constructed by later embed.New calls in the same
// process. Nil fields are ignored.
type Performance struct {
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
// leaves the pool unbounded. Lower values trade throughput for a
// tighter memory ceiling. May also be changed on a running Client via
// Client.SetPerformance, provided this field was nonzero at construction.
PreallocatedBuffersPerPool *uint32
// MaxBatchSize overrides the number of packets the tunnel reads or
// writes per syscall, which also bounds eager buffer allocation per
// worker. Zero uses the platform default. Applied at construction
// only; ignored by Client.SetPerformance.
MaxBatchSize *uint32
}
// validateCredentials checks that exactly one credential type is provided
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound,
BlockLANAccess: &opts.BlockLANAccess,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
config.PrivateKey = opts.PrivateKey
}
if opts.Performance.PreallocatedBuffersPerPool != nil {
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
}
if opts.Performance.MaxBatchSize != nil {
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
}, nil
}
// IdentityForIP looks up a remote peer by its tunnel IP using the
// embedded client's status recorder. Returns the peer's WireGuard public
// key and FQDN. ok=false means the IP isn't in this client's peer
// roster — callers should treat that as "unknown peer".
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if !ip.IsValid() || c.recorder == nil {
return "", "", false
}
state, found := c.recorder.PeerStateByIP(ip.String())
if !found {
return "", "", false
}
return state.PubKey, state.FQDN, true
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
// takes effect, and only when it was nonzero at construction;
// MaxBatchSize is construction-only and returns an error if set here.
//
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
// running yet.
func (c *Client) SetPerformance(t Performance) error {
if t.MaxBatchSize != nil {
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
}
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetPerformance(internal.Performance{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.

View File

@@ -21,7 +21,7 @@ func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
fm, err := uspfilter.Create(iface, nil, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}

View File

@@ -29,47 +29,80 @@ const (
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// SkipNftablesEnv is the environment variable to skip nftables check
const SkipNftablesEnv = "NB_SKIP_NFTABLES_CHECK"
// errNoFirewallManager indicates no kernel firewall backend is present,
// as opposed to a backend that exists but failed to create or initialize.
var errNoFirewallManager = errors.New("no firewall manager found")
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
// We run in userspace mode and force userspace firewall was requested.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
nativeFw, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding without it", err)
}
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
return createUserspaceFirewall(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
switch {
case err == nil && !iface.IsUserspaceBind():
// Nothing to do, fall through
case err == nil && iface.IsUserspaceBind():
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
case err != nil && !iface.IsUserspaceBind():
// Kernel cannot fall back to anything else, need to return error
return nil, err
case err != nil && iface.IsUserspaceBind():
// Fall back to the userspace packet filter if native is unavailable
logNativeFirewallUnavailable(err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
}
// createUserspaceFirewall builds the userspace packet filter, optionally
// backed by a native firewall, and allows netbird interface traffic.
func createUserspaceFirewall(iface IFaceMapper, nativeFw firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
fm, err := uspfilter.Create(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
// logNativeFirewallUnavailable logs the fallback to userspace at info level
// when no kernel firewall backend exists, and at warn level otherwise.
func logNativeFirewallUnavailable(err error) {
if errors.Is(err, errNoFirewallManager) {
log.Infof("no native firewall backend available: %v. Proceeding with userspace", err)
} else {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface, mtu)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
return nil, fmt.Errorf("create firewall: %w", err)
}
if err = fm.Init(stateManager); err != nil {
@@ -88,29 +121,10 @@ func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface, mtu)
default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
return nil, errNoFirewallManager
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
}
if errUsp != nil {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
useIPTABLES := false
@@ -132,35 +146,38 @@ func check() FWType {
}
}
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES {
return NFTABLES
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
// Honor the skip env before probing nftables at all.
if os.Getenv(SkipNftablesEnv) != "true" {
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil {
if !useIPTABLES {
return NFTABLES
}
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
}
}

View File

@@ -1,554 +0,0 @@
package iptables
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
return &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
specs: specs,
v6: m.v6,
}}, nil
}
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
}
if err := m.createIPSet(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err)
}
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// Insert at the beginning of the chain (position 1)
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
} else {
err = m.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
return nil, err
}
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
mangleSpecs: mangleSpecs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
v6: m.v6,
}
m.updateState()
return []firewall.Rule{rule}, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
ip := net.ParseIP(r.ip)
if ip == nil {
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ipsetList.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
shouldDestroyIpset = true
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
}
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState()
return nil
}
func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["INPUT"] {
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range m.entries["FORWARD"] {
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
}
if err := m.destroyIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
}
m.ipsetStore.deleteIpset(ipsetName)
}
return nil
}
func (m *aclManager) createDefaultChains() error {
// chain netbird-acl-input-rules
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
if ipsetName == "" {
return ""
}
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
}
switch {
case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport" + actionSuffix
case sPort != nil:
return ipsetName + "-sport" + actionSuffix
case dPort != nil:
return ipsetName + "-dport" + actionSuffix
default:
return ipsetName + actionSuffix
}
}
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -0,0 +1,352 @@
//go:build !android
package iptables
import (
"fmt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFwdIn, tableFilter},
{chainRTFwdOut, tableFilter},
{chainRTPre, tableMangle},
{chainRTNAT, tableNat},
{chainRTRdr, tableNat},
{chainRTMSSClamp, tableMangle},
} {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFwdIn); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFwdOut); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil
}
func (r *family) addJumpRules() error {
// Jump to nat chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPostrouting, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %w", err)
}
r.rules[jumpNATPost] = natRule
// Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPre}
if err := r.iptablesClient.Insert(tableMangle, chainPrerouting, 1, preRule...); err != nil {
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
}
r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRdr}
if err := r.iptablesClient.Insert(tableNat, chainPrerouting, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %w", err)
}
r.rules[jumpNATPre] = rdrRule
return nil
}
func (r *family) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPrerouting, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPostrouting, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
// seedInitialEntries adds default rules to the entries map. Rules are
// inserted at position 1, so the order here is reversed.
//
// Existing FORWARD policy decides outbound traffic towards our
// interface. If FORWARD policy is "drop", we add an
// established/related rule to allow return traffic for inbound rules.
func (r *family) seedInitialEntries() {
established := getConntrackEstablished()
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", chainACLInput})
r.appendToEntries(chainInput, append([]string{"-i", r.wgIface.Name()}, established...))
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainForward, []string{"-o", r.wgIface.Name(), "-j", chainRTFwdOut})
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", chainRTFwdIn})
// Mangle FORWARD guard: when external DNAT redirects traffic from
// the wg interface, it traverses FORWARD instead of INPUT,
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
// inserted above ours. Mangle runs before filter, so these guard
// rules enforce the ACL mark check where it cannot be overridden.
r.appendToEntries(mangleForwardKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
r.appendToEntries(mangleForwardKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (r *family) seedInitialOptionalEntries() {
r.optionalEntries[chainForward] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
r.entries[chain] = append(r.entries[chain], spec)
}
func (r *family) createDefaultChains() error {
if err := r.iptablesClient.NewChain(tableName, chainACLInput); err != nil {
return fmt.Errorf("create %s chain: %w", chainACLInput, err)
}
for chain, rules := range r.entries {
// mangle FORWARD guard rules are handled separately below
if chain == mangleForwardKey {
continue
}
for _, rule := range rules {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), 1, rule...); err != nil {
return fmt.Errorf("insert jump rule into %s: %w", chain, err)
}
}
}
for chain, entries := range r.optionalEntries {
for _, entry := range entries {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
r.entries[chain] = append(r.entries[chain], entry.spec)
}
}
clear(r.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range r.entries[mangleForwardKey] {
if err := r.iptablesClient.AppendUnique(tableMangle, chainForward, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
func (r *family) cleanUpDefaultForwardRules() error {
var merr *multierror.Error
// cleanJumpRules removes the OUTPUT jump to NETBIRD-NAT-OUTPUT among
// the others, so the chain below deletes cleanly instead of failing
// with "device or resource busy".
if err := r.cleanJumpRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clean jump rules: %w", err))
}
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFwdIn, tableFilter},
{chainRTFwdOut, tableFilter},
{chainRTPre, tableMangle},
{chainRTNAT, tableNat},
{chainRTRdr, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSClamp, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
continue
}
if ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
}
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanJumpRules() error {
// locations maps each tracked jump rule to the built-in table and
// chain it was inserted into.
locations := map[firewall.RuleID]struct{ table, chain string }{
jumpNATPost: {tableNat, chainPostrouting},
jumpManglePre: {tableMangle, chainPrerouting},
jumpNATPre: {tableNat, chainPrerouting},
jumpMSSClamp: {tableMangle, chainForward},
jumpNATOutput: {tableNat, chainOutput},
}
var merr *multierror.Error
for ruleID, loc := range locations {
rule, exists := r.rules[ruleID]
if !exists {
continue
}
if err := r.iptablesClient.DeleteIfExists(loc.table, loc.chain, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule from chain %s in table %s: %w", loc.chain, loc.table, err))
continue
}
delete(r.rules, ruleID)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanAclChains() error {
var merr *multierror.Error
if err := r.cleanInputAclChain(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanPreroutingEntries(); err != nil {
merr = multierror.Append(merr, err)
}
for _, rule := range r.entries[mangleForwardKey] {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete mangle %s guard rule %v: %w", chainForward, rule, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanInputAclChain() error {
ok, err := r.iptablesClient.ChainExists(tableName, chainACLInput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainACLInput, err)
}
if !ok {
return nil
}
var merr *multierror.Error
for _, rule := range r.entries[chainInput] {
if err := r.iptablesClient.DeleteIfExists(tableName, chainInput, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainInput, rule, err))
}
}
for _, rule := range r.entries[chainForward] {
if err := r.iptablesClient.DeleteIfExists(tableName, chainForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainForward, rule, err))
}
}
if err := r.iptablesClient.ClearAndDeleteChain(tableName, chainACLInput); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clear and delete %s chain: %w", chainACLInput, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanPreroutingEntries() error {
ok, err := r.iptablesClient.ChainExists(tableMangle, chainPrerouting)
if err != nil {
return fmt.Errorf("check chain %s in %s: %w", chainPrerouting, tableMangle, err)
}
if !ok {
return nil
}
var merr *multierror.Error
for _, rule := range r.entries[chainPrerouting] {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainPrerouting, rule, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPostrouting, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,269 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleID := rule.ID()
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[firewall.RuleID]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleID+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRdr,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleID+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleID+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFwdOut,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleID := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleID+dnatSuffix)
}
if snatRule, exists := r.rules[ruleID+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleID+snatSuffix)
}
if fwdRule, exists := r.rules[ruleID+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFwdOut, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleID+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
info := ruleInfo{
table: tableNat,
chain: chainRTRdr,
rule: dnatRule,
}
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = info.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNATOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, chainOutput, 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNATOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}

View File

@@ -0,0 +1,246 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableName = tableFilter
tableNat = "nat"
tableMangle = "mangle"
// chainACLInput is the peer ACL chain that holds installed
// peer-filtering rules.
chainACLInput = "NETBIRD-ACL-INPUT"
// mangleForwardKey is the entries map key for mangle FORWARD guard
// rules that prevent external DNAT from bypassing ACL rules.
mangleForwardKey chainKey = "MANGLE-FORWARD"
chainInput = "INPUT"
chainPostrouting = "POSTROUTING"
chainPrerouting = "PREROUTING"
chainForward = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFwdIn = "NETBIRD-RT-FWD-IN"
chainRTFwdOut = "NETBIRD-RT-FWD-OUT"
chainRTPre = "NETBIRD-RT-PRE"
chainRTRdr = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSClamp = "NETBIRD-RT-MSSCLAMP"
jumpManglePre = "jump-mangle-pre"
jumpNATPre = "jump-nat-pre"
jumpNATPost = "jump-nat-post"
jumpNATOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
fwdSuffix firewall.RuleID = "_fwd"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
)
type ruleInfo struct {
chain string
table string
rule []string
}
type routeRules map[firewall.RuleID][]string
// ruleSpec is a single iptables rule expressed as its argument list
// (e.g. {"-i", "wg0", "-j", "DROP"}).
type ruleSpec []string
// chainKey identifies the chain a seeded entry belongs to. It holds
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
// synthetic mangleForwardKey bucket for the mangle FORWARD guard rules.
type chainKey string
// aclEntries maps a chain to the rules seeded into it to jump into or
// guard the netbird ACL chains.
type aclEntries map[chainKey][]ruleSpec
type entry struct {
spec ruleSpec
position int
}
// ipsetCounter is the shared hash:net refcounter used by peer and
// route ACLs alike. The ipset library does not support comments, so
// the key is just the set name (string).
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
// family holds the per-address-family iptables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6.
type family struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
v6 bool
// Peer ACL chain bookkeeping.
entries aclEntries
optionalEntries map[chainKey][]entry
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[nbid.RuleID]*Rule
ipsetCounter *ipsetCounter
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
// plumbing that isn't a filter rule).
rules routeRules
// Routing / NAT.
legacyManagement bool
mtu uint16
ipFwdState *ipfwdstate.IPForwardingState
stateManager *statemanager.Manager
}
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
iptablesClient: iptablesClient,
wgIface: wgIface,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
entries: make(aclEntries),
optionalEntries: make(map[chainKey][]entry),
filters: make(map[nbid.RuleID]*Rule),
rules: make(routeRules),
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
return r, nil
}
// init wires the family to the state manager and installs both the
// route ACL containers and the peer ACL chain skeleton.
func (r *family) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.seedInitialEntries()
r.seedInitialOptionalEntries()
if err := r.cleanAclChains(); err != nil {
return fmt.Errorf("clean acl chains: %w", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
r.updateState()
return nil
}
// Reset tears down all firewall state owned by this family. ACL
// chain cleanup runs before route-chain cleanup because the route
// chains are still referenced by FORWARD jumps installed during
// seedInitialEntries; deleting them first would trip EBUSY.
func (r *family) Reset() error {
var merr *multierror.Error
if err := r.cleanAclChains(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
clear(r.rules)
clear(r.filters)
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
currentState.ACLEntries6 = r.entries
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
currentState.ACLEntries = r.entries
}
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}

View File

@@ -0,0 +1,334 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nbnet "github.com/netbirdio/netbird/client/net"
)
// AddFilterRule installs a packet-filtering rule. With destination
// empty, the rule goes to the peer ACL input chain plus a paired
// mangle PREROUTING rule for the redirect mark. With destination set
// (prefix or named set), it goes to the route ACL forward chain.
// Multi-source rules collapse to one iptables rule via the shared
// hash:net ipset.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
if err != nil {
return nil, fmt.Errorf("apply source match: %w", err)
}
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
if err != nil {
r.dropSourceMatch(srcMatch)
return nil, err
}
r.filters[ruleID] = rule
r.updateState()
return rule, nil
}
func (r *family) hasRule(id nbid.RuleID) bool {
_, ok := r.filters[id]
return ok
}
// hasDNATRule reports whether this family owns the DNAT rule set for
// the given user id. DNAT rules live in r.rules under the well-known
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
// to pick the right family.
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. The
// rule's stored chain/table identify where to delete from; source set
// references are recovered from the spec via findSets and dropped
// from the shared ipset counter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
if err := r.iptablesClient.Delete(tableFilter, pr.chain, pr.specs...); err != nil {
return fmt.Errorf("delete rule %s: %w", pr.chain, err)
}
if pr.mangleSpecs != nil {
if err := r.iptablesClient.Delete(tableMangle, chainRTPre, pr.mangleSpecs...); err != nil {
log.Errorf("delete mangle rule: %v", err)
}
}
r.dropSourceMatch(pr.specs)
delete(r.filters, ruleID)
r.updateState()
return nil
}
// findSets scans an iptables rule spec for "-m set --match-set <name>
// <dir>" fragments and returns the named sets in occurrence order.
// Used at delete time to drop ipsetCounter references.
func findSets(rule []string) []string {
var sets []string
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3])
}
}
return sets
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
// applySourceMatch returns the iptables match fragment for the rule's
// source. For a Set it increments the shared ipset's refcount; for a
// Prefix it emits a direct -s match; for the wildcard it returns nil.
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
switch {
case network.IsSet():
if r.ipsetCounter == nil {
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
}
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
}
return []string{"-m", "set", matchSet, name, "src"}, nil
case network.IsPrefix():
return []string{"-s", network.Prefix.String()}, nil
default:
return nil, nil
}
}
// dropSourceMatch undoes whatever applySourceMatch reserved. Safe to
// call when the spec is empty or holds only inline matchers. Decrement
// errors are logged but not returned: the filter rule has already been
// deleted at that point and we don't want to leak the deletion.
func (r *family) dropSourceMatch(srcMatch []string) {
if r.ipsetCounter == nil {
return
}
for _, name := range findSets(srcMatch) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
log.Errorf("rollback ipset decrement %s: %v", name, err)
}
}
}
// decrementSetCounter drops ipset references owned by a raw rule spec
// stored in r.rules (NAT / legacy route entries). It returns an error
// aggregate so the caller surfaces decrement failures.
func (r *family) decrementSetCounter(rule []string) error {
if r.ipsetCounter == nil {
return nil
}
var merr *multierror.Error
for _, name := range findSets(rule) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// installFilterRule assembles and writes one iptables filter-chain
// rule. With destination empty the rule lands in the peer ACL input
// chain and a paired mangle PREROUTING rule is added for the redirect
// mark. With destination set the rule lands in the route ACL forward
// chain and there is no mangle pairing.
func (r *family) installFilterRule(
ruleID nbid.RuleID,
srcMatch []string,
destination firewall.Network,
protocol firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (*Rule, error) {
isRoute := !destination.IsZero()
proto := protoForFamily(protocol, r.v6)
specs := slices.Clone(srcMatch)
var destExp []string
if isRoute {
var err error
destExp, err = r.applyNetwork("-d", destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
}
specs = append(specs, destExp...)
}
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
var mangleSpecs []string
if !isRoute {
mangleSpecs = slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", r.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
}
specs = append(specs, "-j", actionToStr(action))
chain := chainACLInput
if isRoute {
chain = chainRTFwdIn
}
// Peer ACL drops are inserted at position 1 so they precede the
// chain's catch-all; route ACL drops are inserted at position 2
// to sit immediately after the established/related accept rule.
var err error
if action == firewall.ActionDrop {
pos := 1
if isRoute {
pos = 2
}
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
} else {
err = r.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
r.dropSourceMatch(destExp)
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
}
// The mangle redirect-mark rule is best effort: the filter rule itself
// is what enforces the ACL, so a mangle failure must not undo it. Drop
// the spec so teardown does not try to remove a rule that was not added.
if mangleSpecs != nil {
if err := r.iptablesClient.Append(tableMangle, chainRTPre, mangleSpecs...); err != nil {
log.Errorf("add mangle rule: %v", err)
mangleSpecs = nil
}
}
return &Rule{
id: ruleID,
specs: specs,
mangleSpecs: mangleSpecs,
chain: chain,
v6: r.v6,
}, nil
}
// applyNetwork resolves a firewall.Network into the iptables match
// fragment for the given direction flag (-s or -d). Set networks
// increment the shared ipset refcount; prefixes emit a direct match;
// an empty network returns no spec ("match any").
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, name, direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
// filterMatchSpecs returns the proto/port match fragment for a
// filtering rule. The source match (-s or -m set) is built by the
// caller and prepended.
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
}
if port.IsRange && len(port.Values) == 2 {
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
}
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(int(p))
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}

View File

@@ -0,0 +1,104 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"github.com/hashicorp/go-multierror"
"github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
if err := r.createIPSet(setName); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
// The refcounter records nothing when this callback errors,
// so destroy the set or it leaks in the kernel. A partial
// source set would also fail-open for deny rules, so the
// rule must fail rather than install with a missing source.
if derr := r.destroyIPSet(setName); derr != nil {
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
}
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
}
func (r *family) deleteIpSet(setName string) error {
if err := r.destroyIPSet(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
log.Debugf("deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error
for _, prefix := range prefixes {
if err := r.addPrefixToIPSet(name, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", name, prefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *family) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *family) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -3,7 +3,6 @@ package iptables
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
@@ -18,25 +17,21 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type resetter interface {
Reset() error
}
// Manager of iptables firewall
// Manager of iptables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
type Manager struct {
mutex sync.Mutex
wgIface iFaceMapper
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
family4 *family
rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
aclMgr6 *aclManager
router6 *router
family6 *family
}
// iFaceMapper defines subset methods of interface required for manager
@@ -57,14 +52,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(iptablesClient, wgIface, mtu)
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
return nil, fmt.Errorf("create family: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -81,21 +71,18 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
m.router6, err = newRouter(ip6Client, wgIface, mtu)
family6, err := newFamily(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
return fmt.Errorf("create v6 family: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// Share the same IP forwarding state with the v4 family, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
family6.ipFwdState = m.family4.ipFwdState
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
m.ipv6Client = ip6Client
m.family6 = family6
return nil
}
@@ -109,7 +96,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
MTU: m.family4.mtu,
},
}
stateManager.RegisterState(state)
@@ -141,31 +128,24 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil
}
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
// initChains initializes the per-family firewall state for both
// address families, rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
init func(*statemanager.Manager) error
mgr resetter
r *family
}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
steps := []initStep{{"v4", m.family4}}
if m.hasIPv6() {
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
steps = append(steps, initStep{"v6", m.family6})
}
var initialized []initStep
for _, s := range steps {
if err := s.init(stateManager); err != nil {
if err := s.r.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
if rerr := initialized[i].r.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
@@ -176,84 +156,50 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
return nil
}
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
// docs for destination semantics. Sources are a single address family;
// the rule is dispatched to the matching v4 / v6 backend.
func (m *Manager) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
m.mutex.Lock()
defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) {
fam := m.family4
if isIPv6Rule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
fam = m.family6
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
}
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// DeleteFilterRule removes a rule previously added via AddFilterRule.
// The rule is looked up by id in each family's filter cache.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
id := rule.ID()
if m.family4.hasRule(id) {
return m.family4.DeleteFilterRule(rule)
}
return m.aclMgr.DeletePeerRule(rule)
}
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
if m.hasIPv6() && m.family6.hasRule(id) {
return m.family6.DeleteFilterRule(rule)
}
return m.router.DeleteRouteRule(rule)
log.Debugf("filter rule %s not found in any family", id)
return nil
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -272,10 +218,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
return m.family6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
if err := m.family4.AddNatRule(pair); err != nil {
return err
}
@@ -284,7 +230,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
if err := m.family6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -300,18 +246,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
return m.family6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
if err := m.family4.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -320,11 +266,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
return firewall.SetLegacyManagement(m.family6, isLegacy)
}
return nil
}
@@ -341,19 +287,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
}
if m.hasIPv6() {
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
}
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
}
// Appending to merr intentionally blocks DeleteState below so ShutdownState
@@ -377,11 +317,11 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
var merr *multierror.Error
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
}
if m.hasIPv6() {
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
}
}
@@ -402,14 +342,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -424,9 +364,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
return m.family6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
return m.family4.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -434,10 +374,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
return m.family6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule)
return m.family4.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
@@ -454,12 +394,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -476,9 +416,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -490,9 +430,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -504,9 +444,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -518,14 +458,14 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
chainOutput = "OUTPUT"
tableRaw = "raw"
)
@@ -600,15 +540,15 @@ func (m *Manager) initNoTrackChain() error {
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if err := m.ipv4Client.InsertUnique(tableRaw, chainOutput, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
if err := m.ipv4Client.InsertUnique(tableRaw, chainPrerouting, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
@@ -635,11 +575,11 @@ func (m *Manager) cleanupNoTrackChain() error {
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPrerouting, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
@@ -654,3 +594,13 @@ func (m *Manager) cleanupNoTrackChain() error {
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}
// isIPv6Rule reports whether the rule belongs to the IPv6 family, from
// the destination prefix when set, otherwise from the (single-family)
// sources.
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,3 +1,5 @@
//go:build integration && !android
package iptables
import (
@@ -65,46 +67,39 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 []fw.Rule
var rule2 fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
rr := rule2.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
})
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
ok, err := ipv4Client.ChainExists("filter", chainACLInput)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainACLInput)
}
})
}
@@ -126,15 +121,13 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{22}}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
require.NoError(t, err, "failed to add deny rule")
require.NotEmpty(t, rule, "deny rule should not be empty")
require.NotNil(t, rule, "deny rule should not be nil")
// Verify the rule was added by checking iptables
for _, r := range rule {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
rr := rule.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
})
t.Run("deny rule precedence test", func(t *testing.T) {
@@ -142,36 +135,40 @@ func TestIptablesManagerDenyRules(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
// Add accept rule first
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for same IP/port - this should take precedence
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
require.NoError(t, err, "failed to add deny rule")
// Inspect the actual iptables rules to verify deny rule comes before accept rule
rules, err := ipv4Client.List("filter", chainNameInputRules)
rules, err := ipv4Client.List("filter", chainACLInput)
require.NoError(t, err, "failed to list iptables rules")
// Debug: print all rules
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
t.Logf("All iptables rules in chain %s:", chainACLInput)
for i, rule := range rules {
t.Logf(" [%d] %s", i, rule)
}
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
// match. Match on that shape instead of the legacy
// per-(action,port) ipset names ("deny-http"/"accept-http")
// that this test predates.
srcMatch := fmt.Sprintf("-s %s/32", ip)
var denyRuleIndex, acceptRuleIndex = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
denyRuleIndex = i
}
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
continue
}
if strings.Contains(rule, "ACCEPT") {
if strings.Contains(rule, "-j DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
denyRuleIndex = i
}
if strings.Contains(rule, "-j ACCEPT") {
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
acceptRuleIndex = i
}
acceptRuleIndex = i
}
}
@@ -196,7 +193,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
}
// just check on the local interface
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -210,27 +206,39 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
var rule2 fw.Rule
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
}
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
require.NotNil(t, rule2)
require.Contains(t, rule2.(*Rule).specs, "-s",
"single-source rule should use direct -s match, not an ipset")
require.Empty(t, findSets(rule2.(*Rule).specs),
"single-source rule should not allocate a shared ipset")
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
t.Run("delete single-source rule", func(t *testing.T) {
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
})
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
t.Run("multi-source uses shared ipset", func(t *testing.T) {
sources := []netip.Prefix{
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
}
port := &fw.Port{Values: []uint16{8080}}
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add multi-source rule")
require.NotNil(t, multi, "multi-source rule must produce one iptables rule")
sets := findSets(multi.(*Rule).specs)
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
require.NoError(t, manager.DeleteFilterRule(multi))
})
t.Run("reset check", func(t *testing.T) {
@@ -281,7 +289,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build integration && !android
package iptables
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
@@ -52,12 +52,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
exists, err := manager.iptablesClient.Exists(tableNat, chainPostrouting, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPostrouting)
require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
exists, err = manager.iptablesClient.Exists(tableMangle, chainPrerouting, "-j", chainRTPre)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPrerouting)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
@@ -95,7 +95,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -106,8 +106,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
if testCase.InputPair.Masquerade {
require.True(t, exists, "marking rule should be created")
foundRule, found := manager.rules[natRuleKey]
@@ -121,7 +121,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -132,8 +132,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
@@ -170,7 +170,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -181,8 +181,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey]
@@ -190,7 +190,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -201,8 +201,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager")
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
require.NoError(t, err, "Failed to reset family")
}()
tests := []struct {
@@ -334,62 +334,30 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
stored, ok := r.filters[ruleKey.ID()]
require.True(t, ok, "rule not stored in filters")
t.Logf("Internal rule: %v", stored.specs)
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
exists, err := iptablesClient.Exists(tableFilter, chainRTFwdIn, stored.specs...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content
params := routeFilteringRuleParams{
Source: source,
Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
}
expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
expectedRule, err = r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
})
}
}
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
@@ -430,7 +398,7 @@ func TestFindSetNameInRule(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
result := findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)

View File

@@ -0,0 +1,269 @@
//go:build !android
package iptables
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if !pair.Masquerade {
return nil
}
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
r.updateState()
return nil
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", "ACCEPT"}
if err := r.iptablesClient.Append(tableFilter, chainRTFwdIn, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
r.rules[ruleID] = rule
return nil
}
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
}
return nil
}
// GetLegacyManagement returns the current legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
} else {
delete(r.rules, k)
}
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", "MASQUERADE",
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %w", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", "MASQUERADE",
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %w", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSClamp,
}
if err := r.iptablesClient.Insert(tableMangle, chainForward, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSClamp, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *family) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
ruleID := firewall.RuleID("established-" + chain)
r.rules[ruleID] = establishedRule
return nil
}
func (r *family) addNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.NatFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleID)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
)
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPre, 1, rule...); err != nil {
r.dropSourceMatch(rule)
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
}
r.rules[ruleID] = rule
r.updateState()
return nil
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.NatFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("marking rule %s not found", ruleID)
}
r.updateState()
return nil
}

View File

@@ -1,18 +1,20 @@
package iptables
// Rule to handle management of rules
type Rule struct {
ruleID string
ipsetName string
import "github.com/netbirdio/netbird/client/firewall/manager"
// Rule to handle management of rules. Source set membership (when the
// rule was built against a shared hash:net ipset) is encoded in specs;
// DeleteFilterRule recovers it via findSets so the refcounter can drop
// the right reference.
type Rule struct {
id manager.RuleID
specs []string
mangleSpecs []string
ip string
chain string
v6 bool
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
}

View File

@@ -1,103 +0,0 @@
package iptables
import "encoding/json"
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) *ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return &ipList{
ips: ips,
}
}
func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
type ipsetStore struct {
ipsets map[string]*ipList
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]*ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) ipsetNames() []string {
names := make([]string, 0, len(s.ipsets))
for name := range s.ipsets {
names = append(names, name)
}
return names
}
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -29,17 +29,13 @@ type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules routeRules `json:"route_rules,omitempty"`
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
}
func (s *ShutdownState) Name() string {
@@ -57,17 +53,14 @@ func (s *ShutdownState) Cleanup() error {
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
ipt.family4.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
ipt.family4.entries = s.ACLEntries
}
// Clean up v6 state even if the current run has no IPv6.
@@ -79,16 +72,13 @@ func (s *ShutdownState) Cleanup() error {
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.router6.rules = s.RouteRules6
ipt.family6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
ipt.family6.entries = s.ACLEntries6
}
}

View File

@@ -0,0 +1,27 @@
//go:build integration && !android
package iptables
import (
"fmt"
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
}
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -3,7 +3,6 @@ package manager
import (
"errors"
"fmt"
"net"
"net/netip"
"sort"
@@ -16,6 +15,12 @@ import (
// method but the IPv6 firewall components were not initialized.
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
// ErrNoSources is returned when AddFilterRule is called with an empty
// source list. "Match any source" must be expressed explicitly with a
// /0 prefix; an empty list is a caller error and is rejected rather
// than silently widening the rule to every source.
var ErrNoSources = errors.New("rule has no sources")
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
@@ -23,13 +28,18 @@ const (
NatFormat = "netbird-nat-%s-%t"
)
// RuleID identifies a firewall rule. It is a typed string so the
// compiler catches accidental mixing with arbitrary string keys. It is
// only an identifier and does not implement Rule.
type RuleID string
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// ID returns the rule id
ID() string
ID() RuleID
}
// RuleDirection is the traffic direction which a rule is applied
@@ -91,6 +101,13 @@ func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// IsZero returns true if the network designates no destination, i.e. it
// is the zero value. A zero Network is the peer-rule sentinel; a non-zero
// one carries a prefix or set destination.
func (d Network) IsZero() bool {
return !d.IsPrefix() && !d.IsSet()
}
// Manager is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
@@ -101,43 +118,42 @@ type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
// AddFilterRule adds a packet-filtering rule to the firewall.
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
// If destination is the zero Network, the rule applies to traffic
// inbound to this node, i.e. peer ACL semantics, installed in
// the kernel's input chain. If destination is set (prefix or
// set), the rule applies to forwarded traffic with that
// destination, route ACL semantics, installed in the forward
// chain.
//
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
// sources must be a single address family; the caller splits mixed
// families and calls once per family. "Match any source" must be
// expressed with an explicit /0 prefix; an empty sources list is
// rejected with ErrNoSources so a zeroed list can never widen a
// rule to every source.
//
// Note: callers should call Flush() after adding rules.
AddFilterRule(
id []byte,
ip net.IP,
sources []netip.Prefix,
destination Network,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
ipsetName string,
) ([]Rule, error)
) (Rule, error)
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// DeleteFilterRule removes a filtering rule previously added via
// AddFilterRule. The rule's own type identifies whether it lives
// in the peer (input) or route (forward) path.
DeleteFilterRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
proto Protocol,
sPort, dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
@@ -185,8 +201,9 @@ type Manager interface {
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
// GenKey builds the rule id for this pair from the given format.
func (p RouterPair) GenKey(format string) RuleID {
return RuleID(fmt.Sprintf(format, p.ID, p.Inverse))
}
// LegacyManager defines the interface for legacy management operations
@@ -242,6 +259,20 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
// plain v4 form, shifting the prefix length out of the 96-bit mapped
// range. Other prefixes are returned unchanged. Keeping prefixes
// unmapped ensures v4 rules match consistently and the match builders
// read the correct address length.
func UnmapPrefix(p netip.Prefix) netip.Prefix {
addr := p.Addr()
if !addr.Is4In6() {
return p
}
bits := max(p.Bits()-96, 0)
return netip.PrefixFrom(addr.Unmap(), bits)
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {

View File

@@ -13,13 +13,13 @@ type ForwardRule struct {
TranslatedPort Port
}
func (r ForwardRule) ID() string {
func (r ForwardRule) ID() RuleID {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
return RuleID(id)
}
func (r ForwardRule) String() string {

View File

@@ -40,7 +40,7 @@ func (h Set) Comment() string {
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
prefixes = slices.Clone(prefixes)
SortPrefixes(prefixes)
hash := sha256.New()

View File

@@ -1,713 +0,0 @@
package nftables
import (
"bytes"
"fmt"
"net"
"slices"
"strconv"
"strings"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
// rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
)
const flushError = "flush: %w"
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
af addrFamily
workTable *nftables.Table
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, fmt.Errorf("create nf conn: %w", err)
}
return &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
}, nil
}
func (m *AclManager) init(workTable *nftables.Table) error {
m.workTable = workTable
return m.createDefaultChains()
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
var err error
ipset, err = m.addIpToSet(ipsetName, ip)
if err != nil {
return nil, err
}
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil {
return nil, err
}
newRules = append(newRules, ioRule)
return newRules, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.nftSet == nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
return err
}
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ips) > 0 {
return nil
}
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err
}
delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.rConn.FlushSet(r.nftSet)
m.rConn.DelSet(r.nftSet)
m.ipsetStore.deleteIpset(r.nftSet.Name)
return nil
}
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainInputRules,
Position: 0,
Exprs: expIn,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *AclManager) Flush() error {
if err := m.flushWithBackoff(); err != nil {
return err
}
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
mangleRule: r.mangleRule,
nftSet: r.nftSet,
ruleID: r.ruleID,
ip: ip,
}, nil
}
var expressions []expr.Any
if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.protoOffset,
Len: uint32(1),
})
protoData, err := m.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: []byte{protoData},
})
}
rawIP := ipToBytes(ip, m.af)
// check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.srcAddrOffset,
Len: m.af.addrLen,
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
},
)
}
}
expressions = append(expressions, applyPort(sPort, true)...)
expressions = append(expressions, applyPort(dPort, false)...)
mainExpressions := slices.Clone(expressions)
switch action {
case firewall.ActionAccept:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(ruleId)
chain := m.chainInputRules
rule := &nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: mainExpressions,
UserData: userData,
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var nftRule *nftables.Rule
if action == firewall.ActionDrop {
nftRule = m.rConn.InsertRule(rule)
} else {
nftRule = m.rConn.AddRule(rule)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}
ruleStruct := &Rule{
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = ruleStruct
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return ruleStruct, nil
}
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
m.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
// Chain is created by route manager
// TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
}
chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain
}
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return m.rConn.AddChain(chain)
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ipToBytes(ip, m.af)
if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
m.ipsetStore.newIpset(ipset.Name)
}
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
return ipset, nil
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
m.ipsetStore.AddIpToSet(ipset.Name, ip)
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
return ipset, nil
}
// createSet in given table by name
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: m.af.setKeyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
}
func (m *AclManager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(m.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])]
if ok {
if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
}
}
return nil
}
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipset == nil {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipset.Name + rulesetID
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
// ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
}
return ip.To16()
}

View File

@@ -0,0 +1,885 @@
//go:build !android
package nftables
import (
"bytes"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePostrouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.addPostroutingRules()
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *family) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return nil
}
func (r *family) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) acceptFilterTableRules() error {
if r.filterTable == nil {
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available.
// Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
if err := r.acceptFilterRulesIptables(ipt); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return nil
}
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *family) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *family) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *family) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
r.applyExternalChainAccept(chain, intf)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
return
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
r.insertForwardIifRule(chain, intf, existing)
r.insertForwardOifEstablishedRule(chain, intf, existing)
}
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleIif] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
})
}
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleOif] {
return
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: append(exprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
})
}
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
if existing[userDataAcceptInputRule] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
})
}
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return nil, fmt.Errorf("list rules: %w", err)
}
present := map[string]bool{}
for _, rule := range rules {
if !isNetbirdAcceptRuleTag(rule.UserData) {
continue
}
present[string(rule.UserData)] = true
}
return present, nil
}
func isNetbirdAcceptRuleTag(userData []byte) bool {
switch string(userData) {
case userDataAcceptForwardRuleIif,
userDataAcceptForwardRuleOif,
userDataAcceptInputRule:
return true
}
return false
}
func (r *family) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
}
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != table.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue
}
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *family) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *family) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
return chains
}
func (r *family) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
}
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chainInputRules,
Position: 0,
Exprs: expIn,
})
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (r *family) Flush() error {
if err := r.flushWithBackoff(); err != nil {
return err
}
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (r *family) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if r.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := r.conn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (r *family) createDefaultChains() (err error) {
// chainNameInputRules
chain := r.createChain(chainNameInputRules)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
r.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
r.chainPrerouting = r.chains[chainNameManglePrerouting]
r.addFwmarkToForward(chainFwFilter)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: r.routingFwChainName,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (r *family) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
}
chain = r.conn.AddChain(chain)
insertReturnTrafficRule(r.conn, r.workTable, chain)
return chain
}
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return r.conn.AddChain(chain)
}
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (r *family) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if r.workTable == nil || chain == nil {
return nil
}
list, err := r.conn.GetRules(r.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
if !ok {
continue
}
if mangle {
if pr.mangleRule != nil {
*pr.mangleRule = *rule
}
} else {
*pr.nftRule = *rule
}
}
return nil
}

View File

@@ -0,0 +1,533 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleID := rule.ID()
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := r.af.protoNum(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleID); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleID)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleID, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleID + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleID+dnatSuffix] = dnatRule
return nil
}
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleID firewall.RuleID, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
natTable := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
dnatRule := &nftables.Rule{
Table: natTable,
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: natTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleID + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleID+dnatSuffix] = dnatRule
return nil
}
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.dstAddrOffset,
Len: r.af.addrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleID + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleID+snatSuffix] = masqRule
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleID := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleID+dnatSuffix)
delete(r.rules, ruleID+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleID+snatSuffix]; exists {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleID+snatSuffix)
delete(r.rules, ruleID+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr == nil {
delete(r.rules, ruleID+dnatSuffix)
delete(r.rules, ruleID+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -0,0 +1,249 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
// Peer ACL chain names.
chainNameInputRules = "netbird-acl-input-rules"
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
flushError = "flush: %w"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
)
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
// family holds the per-address-family nftables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6. The name predates the peer-ACL absorption; it's effectively
// the per-family backend now.
type family struct {
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[firewall.RuleID]*Rule
// rules holds NAT, DNAT, and external accept rules (auxiliary
// plumbing that isn't a filter rule).
rules map[firewall.RuleID]*nftables.Rule
// Peer ACL chain handles.
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
routingFwChainName string
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
filters: make(map[firewall.RuleID]*Rule),
rules: make(map[firewall.RuleID]*nftables.Rule),
routingFwChainName: chainNameRoutingFw,
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
}
func (r *family) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from filter table: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default acl chains: %w", err)
}
return nil
}
// Reset cleans existing nftables filter table rules from the system
func (r *family) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
if err != nil {
return nil, fmt.Errorf("list tables: %w", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *family) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *family) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[firewall.RuleID]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
newRules[firewall.RuleID(rule.UserData)] = rule
}
}
}
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,441 @@
//go:build !android
package nftables
import (
"fmt"
"net"
"net/netip"
"slices"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
)
// AddFilterRule installs one nftables packet-filter rule. With
// destination empty the rule goes to the peer ACL input chain plus a
// paired prerouting mangle rule for the redirect mark. With
// destination set (prefix or named set) it goes to the route ACL
// forward chain. Multi-source rules collapse to one nftables rule
// backed by the shared refcounted hash:net set.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
isRoute := !destination.IsZero()
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
var exprs []expr.Any
if isRoute {
exprs, err = r.buildRouteFilterExprs(srcExprs, destination, proto, sPort, dPort)
} else {
exprs, err = r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
}
if err != nil {
r.dropNetworkMatch(srcExprs)
return nil, err
}
mainExprs := slices.Clone(exprs)
verdict := expr.VerdictAccept
if action == firewall.ActionDrop {
verdict = expr.VerdictDrop
}
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
chain := r.chainInputRules
if isRoute {
chain = r.chains[chainNameRoutingFw]
}
userData := []byte(ruleID)
nftRule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: mainExprs,
UserData: userData,
}
if action == firewall.ActionDrop {
nftRule = r.conn.InsertRule(nftRule)
} else {
nftRule = r.conn.AddRule(nftRule)
}
if err := r.conn.Flush(); err != nil {
r.dropNetworkMatch(exprs)
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{
nftRule: nftRule,
sources: sources,
id: ruleID,
}
if !isRoute {
rule.mangleRule = r.createPreroutingRule(exprs, userData)
}
r.filters[ruleID] = rule
log.Debugf("added filter rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
sources, destination, proto, sPort, dPort, action)
return rule, nil
}
// buildPeerFilterExprs assembles the input-chain (peer ACL) match: the
// IP-header protocol byte read via Payload, then source, then ports
// (no counter), matching the historical peer shape so per-rule kernel
// state is identical to pre-unification.
func (r *family) buildPeerFilterExprs(
srcExprs []expr.Any,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
) ([]expr.Any, error) {
var exprs []expr.Any
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.protoOffset,
Len: 1,
},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
}
exprs = append(exprs, srcExprs...)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
return exprs, nil
}
// buildRouteFilterExprs assembles the forward-chain (route ACL) match:
// source, then destination, then optional proto/ports, then a counter.
func (r *family) buildRouteFilterExprs(
srcExprs []expr.Any,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
) ([]expr.Any, error) {
exprs := append([]expr.Any{}, srcExprs...)
destExprs, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExprs...)
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
r.dropNetworkMatch(destExprs)
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
return exprs, nil
}
func (r *family) hasRule(id firewall.RuleID) bool {
_, ok := r.filters[id]
return ok
}
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. Source
// set references are recovered from the stored rule's expressions via
// findSets and dropped from the shared refcounter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
// A freshly added rule carries no handle until it is read back from
// the kernel, and Flush only refreshes the peer chains. Pull live
// handles for this rule's chain before deciding it is stale so route
// rules (which Flush never refreshes) can actually be deleted.
if pr.nftRule.Handle == 0 {
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
log.Warnf("refresh handles for chain %s: %v", pr.nftRule.Chain.Name, err)
}
if pr.mangleRule != nil {
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Warnf("refresh mangle handles: %v", err)
}
}
}
if pr.nftRule.Handle == 0 {
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
r.dropNetworkMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
if err := r.conn.DelRule(pr.nftRule); err != nil {
log.Errorf("queue rule delete: %v", err)
}
if pr.mangleRule != nil {
if err := r.conn.DelRule(pr.mangleRule); err != nil {
log.Errorf("queue mangle rule delete: %v", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete %s: %w", ruleID, err)
}
r.dropNetworkMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
if r.ipsetCounter == nil {
return nil
}
sets := findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// dropNetworkMatch undoes whatever the source/destination match
// reserved. Safe to call when the spec is empty or holds only inline
// matchers.
func (r *family) dropNetworkMatch(exprs []expr.Any) {
if r.ipsetCounter == nil {
return
}
for _, e := range exprs {
lookup, ok := e.(*expr.Lookup)
if !ok {
continue
}
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
}
}
}
func (r *family) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
side := "destination"
if isSource {
side = "source"
}
return nil, fmt.Errorf("%s set: %w", side, err)
}
return exprs, nil
}
if network.IsPrefix() {
return prefixMatchExprs(r.af, network.Prefix, isSource), nil
}
return nil, nil
}
// prefixMatchExprs is the family-aware match sequence for a CIDR
// prefix. /0 returns nil; a host prefix (full bit length for the
// family) skips the bitwise step since the mask is all-ones. Shared
// between family and aclManager so both treat single prefixes
// identically.
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
offset := af.dstAddrOffset
if isSource {
offset = af.srcAddrOffset
}
ones := prefix.Bits()
if ones == 0 {
return nil
}
payload := &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: af.addrLen,
}
cmp := &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
}
if ones == af.totalBits {
return []expr.Any{payload, cmp}
}
mask := net.CIDRMask(ones, af.totalBits)
xor := make([]byte, af.addrLen)
return []expr.Any{
payload,
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: af.addrLen,
Mask: mask,
Xor: xor,
},
cmp,
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
// src
offset := uint32(2)
if isSource {
// dst
offset = 0
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
exprs = append(exprs,
&expr.Range{
Op: expr.CmpOpEq,
Register: 1,
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {
for i, p := range port.Values {
if i > 0 {
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(p),
})
}
}
return exprs
}
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
// findSets scans an nftables rule's expressions for expr.Lookup and
// returns the named sets in occurrence order. Used at delete time to
// drop ipsetCounter references; peer and route ACLs go through it.
func findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName)
}
}
return sets
}

View File

@@ -0,0 +1,196 @@
//go:build !android
package nftables
import (
"encoding/binary"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
set: set,
prefixes: prefixes,
})
if err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return r.getIpSetExprs(ref, isSource)
}
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
prefixes := firewall.MergeIPRanges(input.prefixes)
nfset := &nftables.Set{
Name: setName,
Comment: input.set.Comment(),
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: r.af.setKeyType,
}
elements := r.convertPrefixesToSet(prefixes)
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
// The set is committed now. If a later batch fails, destroy it: the
// refcounter records nothing on a create-callback error, so it would
// otherwise leak, and a partial source set fails-open for deny rules.
if err := r.addRemainingElements(nfset, elements, maxElements); err != nil {
if derr := r.deleteIpSet(setName, nfset); derr != nil {
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
}
return nil, err
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}
// addRemainingElements adds element batches beyond the initial one in
// maxElements-sized chunks, flushing each. Called after the set has been
// created with its first batch.
func (r *family) addRemainingElements(nfset *nftables.Set, elements []nftables.SetElement, maxElements int) error {
nElements := len(elements)
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd := min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, nfset.Name, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
}
return nil
}
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range prefixes {
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
return elements
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
masked := prefix.Masked()
if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
if err != nil {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
// Overlapping prefixes (e.g. duplicate resolved addresses) make the
// interval set reject the batch, so merge them as createIpSet does.
prefixes = firewall.MergeIPRanges(prefixes)
elements := r.convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
offset := r.af.dstAddrOffset
if isSource {
// src offset
offset = r.af.srcAddrOffset
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: r.af.addrLen,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -1,85 +0,0 @@
package nftables
import (
"net"
)
type ipsetStore struct {
ipsetReference map[string]int
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsetReference: make(map[string]int),
ipsets: make(map[string]map[string]struct{}),
}
}
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
s.ipsetReference[ipsetName] = 0
ipList := make(map[string]struct{})
s.ipsets[ipsetName] = ipList
return ipList
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsetReference, ipsetName)
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
delete(ipList, ip.String())
}
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
ipList[ip.String()] = struct{}{}
}
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return false
}
_, ok = ipList[ip.String()]
return ok
}
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
s.ipsetReference[ipsetName]++
}
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
r, ok := s.ipsetReference[ipsetName]
if !ok {
return
}
if r == 0 {
return
}
s.ipsetReference[ipsetName]--
}
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
if _, ok := s.ipsetReference[ipsetName]; !ok {
return false
}
if s.ipsetReference[ipsetName] == 0 {
return false
}
return true
}

View File

@@ -3,7 +3,6 @@ package nftables
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"
@@ -45,18 +44,17 @@ type iFaceMapper interface {
Address() wgaddr.Address
}
// Manager of iptables firewall
// Manager of nftables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
type Manager struct {
mutex sync.Mutex
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
family4 *family
// IPv6 counterpart, nil when no v6 overlay.
family6 *family
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
@@ -75,14 +73,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface, mtu)
m.family4, err = newFamily(workTable, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
return nil, fmt.Errorf("create family: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -100,26 +93,21 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.router6, err = newRouter(workTable6, wgIface, mtu)
m.family6, err = newFamily(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
return fmt.Errorf("create v6 family: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
m.family6.ipFwdState = m.family4.ipFwdState
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.router6 != nil
return m.family6 != nil
}
func (m *Manager) initIPv6() error {
@@ -128,12 +116,8 @@ func (m *Manager) initIPv6() error {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
if err := m.family6.init(workTable6); err != nil {
return fmt.Errorf("v6 family init: %w", err)
}
return nil
@@ -162,13 +146,13 @@ func (m *Manager) reconcileExternalChains() error {
defer m.mutex.Unlock()
var merr *multierror.Error
if m.router != nil {
if err := m.router.acceptExternalChainsRules(); err != nil {
if m.family4 != nil {
if err := m.family4.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
}
}
if m.hasIPv6() {
if err := m.router6.acceptExternalChainsRules(); err != nil {
if err := m.family6.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
}
}
@@ -187,12 +171,8 @@ func (m *Manager) initFirewall() (err error) {
}
}()
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
return fmt.Errorf("acl manager init: %w", err)
if err := m.family4.init(workTable); err != nil {
return fmt.Errorf("family init: %w", err)
}
if m.hasIPv6() {
@@ -220,7 +200,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
MTU: m.family4.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -235,12 +215,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
if err := m.family4.Reset(); err != nil {
log.Warnf("rollback family: %v", err)
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
if err := m.family6.Reset(); err != nil {
log.Warnf("rollback v6 family: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
@@ -251,118 +231,77 @@ func (m *Manager) rollbackInit() {
}
}
// AddPeerFiltering rule to the firewall
// AddFilterRule installs a packet-filtering rule.
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if ip.To4() != nil {
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
// Destination semantics: zero Network → input chain (peer ACL);
// set Network → forward chain (route ACL).
//
// Sources are a single address family; the rule is dispatched to the
// matching per-family backend.
func (m *Manager) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
m.mutex.Lock()
defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) {
fam := m.family4
if isIPv6Rule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
fam = m.family6
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// DeleteFilterRule removes a filtering rule. The owning family is found
// by id, refreshing from the kernel if the in-memory caches miss so a
// stale cache cannot leak the rule. family.DeleteFilterRule is idempotent
// when the id is absent.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule)
}
func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
// router; the cached maps are normally authoritative, so the kernel is only
// consulted when neither map knows about the rule.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
id := rule.ID()
r, err := m.routerForRuleID(id, (*router).hasRule)
fam, err := m.familyForRuleID(rule.ID(), (*family).hasRule)
if err != nil {
return err
}
return r.DeleteRouteRule(rule)
return fam.DeleteFilterRule(rule)
}
// routerForRuleID picks the router holding the rule with the given id, using
// familyForRuleID picks the family holding the rule with the given id, using
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
// from the kernel once and re-checks before falling back to the v4 router.
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
if has(m.router, id) {
return m.router, nil
// from the kernel once and re-checks before falling back to the v4 family.
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool) (*family, error) {
if has(m.family4, id) {
return m.family4, nil
}
if m.hasIPv6() && has(m.router6, id) {
return m.router6, nil
if m.hasIPv6() && has(m.family6, id) {
return m.family6, nil
}
if !m.hasIPv6() {
return m.router, nil
return m.family4, nil
}
if err := m.router.refreshRulesMap(); err != nil {
if err := m.family4.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v4 rules: %w", err)
}
if err := m.router6.refreshRulesMap(); err != nil {
if err := m.family6.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v6 rules: %w", err)
}
if has(m.router6, id) && !has(m.router, id) {
return m.router6, nil
if has(m.family6, id) && !has(m.family4, id) {
return m.family6, nil
}
return m.router, nil
return m.family4, nil
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -381,10 +320,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
return m.family6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
if err := m.family4.AddNatRule(pair); err != nil {
return err
}
@@ -396,7 +335,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// so the eventual cleanup still works.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
if err := m.family6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -412,18 +351,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
return m.family6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
if err := m.family4.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -445,11 +384,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.createDefaultAllowRules(); err != nil {
if err := m.family4.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
if err := m.family6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
@@ -466,11 +405,11 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
return firewall.SetLegacyManagement(m.family6, isLegacy)
}
return nil
}
@@ -484,13 +423,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
}
}
@@ -530,14 +469,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -551,12 +490,12 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.Flush(); err != nil {
if err := m.family4.Flush(); err != nil {
return err
}
if m.hasIPv6() {
if err := m.aclManager6.Flush(); err != nil {
if err := m.family6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
@@ -577,9 +516,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
return m.family6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
return m.family4.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -587,7 +526,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule)
if err != nil {
return err
}
@@ -608,12 +547,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -630,9 +569,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -644,9 +583,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -658,9 +597,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -672,9 +611,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
@@ -903,3 +842,14 @@ func getEstablishedExprs(register uint32) []expr.Any {
},
}
}
// isIPv6Rule reports whether the rule belongs to the v6 table. For a
// prefix destination the destination family decides; otherwise the
// (single-family) sources do, since management duplicates rules per
// family.
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,3 +1,5 @@
//go:build integration && !android
package nftables
import (
@@ -70,13 +72,13 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules")
@@ -147,15 +149,12 @@ func TestNftablesManager(t *testing.T) {
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
for _, r := range rule {
err = manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
require.NoError(t, err, "failed to get rules")
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
@@ -180,47 +179,39 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
testClient := &nftables.Conn{}
// Add accept rule first
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for the same traffic
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err, "failed to add deny rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
require.NoError(t, err, "failed to get rules")
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
// Single-source rules emit a direct payload+cmp on the source IP
// (no set lookup). Match by source-IP + port + verdict instead of
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
// that this test predates.
wantSrc := ip.AsSlice()
var acceptRuleIndex, denyRuleIndex = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
hasPort80 := false
var hasSrc, hasPort80 bool
var action string
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
switch lookup.SetName {
case "accept-http":
hasAcceptHTTPSet = true
case "deny-http":
hasDenyHTTPSet = true
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
if bytes.Equal(cmp.Data, wantSrc) {
hasSrc = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
hasPort80 = true
}
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
switch verdict.Kind {
case expr.VerdictAccept:
@@ -231,11 +222,15 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
}
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
if !hasSrc || !hasPort80 {
continue
}
switch action {
case "ACCEPT":
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
acceptRuleIndex = i
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
case "DROP":
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
denyRuleIndex = i
}
}
@@ -279,7 +274,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -361,10 +356,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
@@ -437,10 +432,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
@@ -550,7 +545,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
@@ -565,7 +560,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
@@ -591,9 +586,9 @@ func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{},
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build integration && !android
package nftables
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
// need fw manager to init both acl mgr and family for all chains to be present
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
nftablesTestingClient := &nftables.Conn{}
rtr := manager.router
rtr := manager.family4
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
@@ -90,9 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
testRouter := &router{af: afIPv4}
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
testRouter := &family{af: afIPv4}
sourceExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Source.Prefix, true)
destExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order
// nolint:gocritic
@@ -100,14 +100,14 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNameManglePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
@@ -135,19 +135,19 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.router
rtr := manager.family4
// First add the NAT rule using the router's method
// First add the NAT rule using the family's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
found = true
break
}
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
found = true
break
}
@@ -200,11 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
require.NoError(t, r.init(workTable))
defer func(r *router) {
defer func(r *family) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
@@ -314,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
require.True(t, ok, "Rule not found in filters map")
rule := stored.nftRule
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
@@ -339,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.ID() {
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
nftRule = rule
break
}
@@ -367,12 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
require.NoError(t, r.Reset(), "Failed to reset family")
}()
tests := []struct {
@@ -509,6 +509,42 @@ func TestNftablesCreateIpSet(t *testing.T) {
}
}
// TestNftablesUpdateSetMergesOverlapping verifies that UpdateSet merges
// overlapping prefixes before adding them. An interval set rejects
// overlapping elements, so without the merge a batch holding a /32 already
// covered by a /24, or a duplicate address as DNS resolution can produce,
// would fail.
func TestNftablesUpdateSetMergesOverlapping(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "create work table")
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "create family")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "reset family")
}()
initial := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
set := firewall.NewPrefixSet(initial)
created, err := r.createIpSet(set.HashedName(), setInput{prefixes: initial})
require.NoError(t, err, "create ip set")
require.NotNil(t, created)
overlapping := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("192.168.1.1/32"),
}
require.NoError(t, r.UpdateSet(set, overlapping), "UpdateSet must merge overlapping prefixes")
}
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
@@ -518,11 +554,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
require.NoError(t, r.Reset(), "Failed to reset family")
}()
tests := []struct {
@@ -861,13 +897,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
ruleKey, err := r.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
@@ -878,11 +914,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
require.NoError(t, r.DeleteFilterRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
@@ -902,6 +938,55 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
// rule is actually removed from the kernel on delete. The route chain is
// not refreshed by Flush, so the stored rule carries a zero handle;
// DeleteFilterRule must pull live handles itself before issuing the
// delete or the kernel rule leaks. Regression test for that path.
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
ruleKey, err := r.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
countKernelRules := func() int {
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err)
n := 0
for _, rule := range list {
if string(rule.UserData) == string(ruleKey.ID()) {
n++
}
}
return n
}
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
require.NoError(t, r.DeleteFilterRule(ruleKey))
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
@@ -911,24 +996,28 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
staleKey := id.RuleID("stale-route-rule")
staleRule := &Rule{
nftRule: &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
},
id: staleKey,
}
r.filters[staleKey] = staleRule
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
// DeleteFilterRule should not return an error for stale handles
err = r.DeleteFilterRule(staleRule)
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
@@ -950,7 +1039,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
Masquerade: true,
}
rtr := manager.router
rtr := manager.family4
// First add succeeds
err = rtr.AddNatRule(pair)
@@ -960,11 +1049,11 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
natRuleKey := pair.GenKey(firewall.PreroutingFormat)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
inverseKey := firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat)
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
@@ -979,7 +1068,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
found++
}
}
@@ -1010,7 +1099,7 @@ func TestCalculateLastIP(t *testing.T) {
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &router{af: afIPv6}
r := &family{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),

View File

@@ -0,0 +1,490 @@
//go:build !android
package nftables
import (
"fmt"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *family) rollbackRules(pair firewall.RouterPair) {
keys := []firewall.RuleID{
pair.GenKey(firewall.ForwardingFormat),
pair.GenKey(firewall.PreroutingFormat),
firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *family) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq
if pair.Inverse {
op = expr.CmpOpNeq
}
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
ruleID := pair.GenKey(firewall.PreroutingFormat)
if _, exists := r.rules[ruleID]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
}
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleID] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: exprs,
UserData: []byte(ruleID),
})
return nil
}
func (r *family) addPostroutingRules() {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
if r.mtu <= overhead {
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
return nil
}
mss := r.mtu - overhead
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return r.conn.Flush()
}
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
var exprs []expr.Any
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
exprs = append(exprs,
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
)
ruleID := pair.GenKey(firewall.ForwardingFormat)
if _, exists := r.rules[ruleID]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleID] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: exprs,
UserData: []byte(ruleID),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleID)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.PreroutingFormat)
rule, exists := r.rules[ruleID]
if !exists {
log.Debugf("prerouting rule %s not found", ruleID)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleID)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}

View File

@@ -1,21 +1,26 @@
package nftables
import (
"net"
"net/netip"
"github.com/google/nftables"
"github.com/netbirdio/netbird/client/firewall/manager"
)
// Rule to handle management of rules
// Rule wraps an installed filter rule (peer or route). Source set
// membership is encoded in the rule's expressions; DeleteFilterRule
// recovers the set name via findSets so the refcounter can drop the
// right reference. mangleRule is set only for peer rules.
type Rule struct {
nftRule *nftables.Rule
mangleRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
// sources is the canonical source list this rule was created for.
sources []netip.Prefix
id manager.RuleID
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
}

View File

@@ -0,0 +1,27 @@
//go:build integration && !android
package nftables
import (
"fmt"
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
}
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"os"
"slices"
@@ -72,12 +71,14 @@ const (
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule
// peerRules is the canonical list-based storage for peer ACL rules.
// Match order is significant: drop rules come before accept rules so
// callers should consult the slice in order.
type peerRules []*PeerRule
type RouteRules []*RouteRule
type routeRules []*RouteRule
func (r RouteRules) Sort() {
func (r routeRules) Sort() {
slices.SortStableFunc(r, func(a, b *RouteRule) int {
// Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
@@ -86,22 +87,44 @@ func (r RouteRules) Sort() {
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
return 1
}
return strings.Compare(a.id, b.id)
return strings.Compare(string(a.id), string(b.id))
})
}
// peerRuleSpec carries the parameters that define a peer filter rule,
// threaded together through the build path so the builders take a single
// argument instead of a long parameter list.
type peerRuleSpec struct {
mgmtID []byte
sources []netip.Prefix
ipLayer gopacket.LayerType
matchAny bool
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
action firewall.Action
}
// Manager userspace firewall manager
type Manager struct {
outgoingRules map[netip.Addr]RuleSet
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
decoders sync.Pool
wgIface common.IFaceMapper
// nativeFirewall is the kernel firewall (nftables/iptables) used for
// the split case where peer ACLs run in userspace here but routing
// stays in the kernel: when the userspace firewall is forced yet the
// router keeps using the kernel, route NAT/ACLs and DNAT are
// delegated to it. It is nil when no native backend is available.
nativeFirewall firewall.Manager
mutex sync.RWMutex
mutex sync.RWMutex
incomingDenyRules peerRules
incomingAcceptRules peerRules
incomingDenyIndex peerRuleIndex
incomingAcceptIndex peerRuleIndex
peerRulesMap map[nbid.RuleID]*PeerRule
routeRules routeRules
routeRulesMap map[nbid.RuleID]*RouteRule
// indicates whether server routes are disabled
disableServerRoutes bool
@@ -183,24 +206,6 @@ func (d *decoder) decodePacket(data []byte) error {
}
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
return mgr, nil
}
func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error
@@ -231,7 +236,7 @@ func parseCreateEnv() (bool, bool, bool) {
return disableConntrack, enableLocalForwarding, disableMSSClamping
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
func Create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{
@@ -254,11 +259,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return d
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet),
incomingDenyRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface,
nativeFirewall: nativeFirewall,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
stateful: !disableConntrack,
@@ -266,6 +268,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
@@ -320,7 +323,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
}
var rules []firewall.Rule
v4Rule, err := m.addRouteFiltering(
v4Rule, err := m.addRouteRule(
nil,
sources,
firewall.Network{Prefix: wgPrefix},
@@ -336,7 +339,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
if v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
v6Rule, err := m.addRouteFiltering(
v6Rule, err := m.addRouteRule(
nil,
sources,
firewall.Network{Prefix: v6Net},
@@ -488,64 +491,136 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
// addPeerRule installs an input-chain rule that matches packets
// by source only. Called from AddFilterRule when the caller doesn't
// specify a destination. Mixed-family inputs are split: each family
// gets its own rule with a family-correct ipLayer so packet decoding
// matches what the matcher expects.
func (m *Manager) addPeerRule(
id []byte,
ip net.IP,
sources []netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
_ string,
) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{
id: uuid.New().String(),
mgmtId: id,
ip: i,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
}
if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
r.matchByIP = false
}
r.sPort = sPort
r.dPort = dPort
r.protoLayer = protoToLayer(proto, r.ipLayer)
) (firewall.Rule, error) {
m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet
if r.drop {
targetMap = m.incomingDenyRules
} else {
targetMap = m.incomingRules
defer m.mutex.Unlock()
if sourcesMatchAny(sources) {
spec := peerRuleSpec{
mgmtID: id,
sources: sources,
ipLayer: layerTypeAll,
matchAny: true,
proto: proto,
sPort: sPort,
dPort: dPort,
action: action,
}
return m.addOnePeerRule(spec), nil
}
if _, ok := targetMap[r.ip]; !ok {
targetMap[r.ip] = make(RuleSet)
// Sources are a single family; normalize v4-mapped prefixes to plain
// v4 and pick the matching IP layer.
normalized := make([]netip.Prefix, len(sources))
ipLayer := layers.LayerTypeIPv4
for i, p := range sources {
normalized[i] = firewall.UnmapPrefix(p)
if normalized[i].Addr().Is6() {
ipLayer = layers.LayerTypeIPv6
}
}
targetMap[r.ip][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
spec := peerRuleSpec{
mgmtID: id,
sources: normalized,
ipLayer: ipLayer,
matchAny: false,
proto: proto,
sPort: sPort,
dPort: dPort,
action: action,
}
return m.addOnePeerRule(spec), nil
}
func (m *Manager) AddRouteFiltering(
// addOnePeerRule builds and registers a single-family peer rule, or
// returns the existing rule when one with the same content key is
// already installed. The caller must hold m.mutex. The content key is
// the shared GenerateRuleID with an empty destination, so peer
// rules dedup the same way route rules and the kernel backends do.
//
// There is no refcount: a content key is installed once and deleted on
// the first DeleteFilterRule for that key. The caller must therefore
// key its own tracking by the returned rule id so add and delete stay
// balanced per content key; the acl manager does this via
// peerRulesPairs. The content key is order-independent, so callers
// passing the same sources in any order dedup to one rule.
func (m *Manager) addOnePeerRule(spec peerRuleSpec) *PeerRule {
ruleID := nbid.GenerateRuleID(spec.sources, firewall.Network{}, spec.proto, spec.sPort, spec.dPort, spec.action)
if existing, ok := m.peerRulesMap[ruleID]; ok {
return existing
}
rule := m.buildPeerRule(ruleID, spec)
m.registerPeerRule(rule)
return rule
}
func (m *Manager) buildPeerRule(ruleID nbid.RuleID, spec peerRuleSpec) *PeerRule {
r := &PeerRule{
id: ruleID,
mgmtId: spec.mgmtID,
sources: spec.sources,
matchAny: spec.matchAny,
action: spec.action,
srcPort: spec.sPort,
dstPort: spec.dPort,
}
if !spec.matchAny {
r.sourceAddrs = make(map[netip.Addr]struct{}, len(spec.sources))
for _, p := range spec.sources {
if p.Bits() == p.Addr().BitLen() {
r.sourceAddrs[p.Addr()] = struct{}{}
}
}
}
r.protoLayer = protoToLayer(spec.proto, spec.ipLayer)
return r
}
// registerPeerRule records a freshly built peer rule in the matching
// slice, index, and dedup map. The caller must hold m.mutex.
func (m *Manager) registerPeerRule(r *PeerRule) {
if r.action == firewall.ActionDrop {
m.incomingDenyRules = append(m.incomingDenyRules, r)
m.incomingDenyIndex.add(r)
} else {
m.incomingAcceptRules = append(m.incomingAcceptRules, r)
m.incomingAcceptIndex.add(r)
}
m.peerRulesMap[r.id] = r
}
// sourcesMatchAny reports whether the source list matches every source,
// i.e. contains an explicit /0 prefix. An empty list does not qualify:
// AddFilterRule rejects it with ErrNoSources, so "match any" is always
// the deliberate /0 case.
func sourcesMatchAny(sources []netip.Prefix) bool {
for _, p := range sources {
if p.Bits() == 0 {
return true
}
}
return false
}
// AddFilterRule is the unified entry point for both peer (input chain)
// and route (forward chain) filtering rules. The destination
// distinguishes the two semantics: a zero Network installs an
// input-side rule that matches by source only; a set Network installs
// a forward-side rule that also matches the destination.
func (m *Manager) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
@@ -553,13 +628,37 @@ func (m *Manager) AddRouteFiltering(
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
if destination.IsZero() {
return m.addPeerRule(id, sources, proto, sPort, dPort, action)
}
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteRule(id, sources, destination, proto, sPort, dPort, action)
}
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
// is used to route to the correct internal path.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
if r, ok := rule.(*PeerRule); ok {
return m.deletePeerRuleLocked(r)
}
// Either our *RouteRule or, under native delegation, a native
// route-rule object that implements firewall.Rule but isn't one of
// our concrete types. The route path forwards the latter to the
// native firewall.
return m.deleteRouteRule(rule)
}
func (m *Manager) addRouteFiltering(
func (m *Manager) addRouteRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
@@ -568,18 +667,17 @@ func (m *Manager) addRouteFiltering(
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
return m.nativeFirewall.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
}
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
if existingRule, ok := m.routeRulesMap[ruleID]; ok {
return existingRule, nil
}
rule := RouteRule{
// TODO: consolidate these IDs
id: string(ruleKey),
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -594,72 +692,57 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
m.routeRulesMap[ruleID] = &rule
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
return m.nativeFirewall.DeleteFilterRule(rule)
}
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
ruleID := rule.ID()
trimmed, _, ok := removeRuleByID(m.routeRules, ruleID)
if !ok {
return fmt.Errorf("route rule not found: %s", ruleID)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
m.routeRules = trimmed
delete(m.routeRulesMap, ruleID)
return nil
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// deletePeerRuleLocked removes a peer rule from the matching slice,
// index, and dedup map. The caller must hold m.mutex.
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
target, index := &m.incomingAcceptRules, &m.incomingAcceptIndex
if r.action == firewall.ActionDrop {
target, index = &m.incomingDenyRules, &m.incomingDenyIndex
}
r, ok := rule.(*PeerRule)
trimmed, stored, ok := removeRuleByID(*target, r.id)
if !ok {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
var sourceMap map[netip.Addr]RuleSet
if r.drop {
sourceMap = m.incomingDenyRules
} else {
sourceMap = m.incomingRules
}
if ruleset, ok := sourceMap[r.ip]; ok {
if _, exists := ruleset[r.id]; !exists {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(ruleset, r.id)
if len(ruleset) == 0 {
delete(sourceMap, r.ip)
}
} else {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
*target = trimmed
index.remove(stored)
delete(m.peerRulesMap, r.id)
return nil
}
// removeRuleByID removes the first rule whose id matches ruleID from
// rules, preserving order. It returns the trimmed slice, the removed
// rule, and whether a match was found.
func removeRuleByID[S ~[]T, T firewall.Rule](rules S, ruleID firewall.RuleID) (S, T, bool) {
idx := slices.IndexFunc(rules, func(r T) bool { return r.ID() == ruleID })
var removed T
if idx < 0 {
return rules, removed, false
}
removed = rules[idx]
return slices.Delete(rules, idx, idx+1), removed, true
}
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
@@ -674,9 +757,11 @@ func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
clear(m.outgoingRules)
clear(m.incomingDenyRules)
clear(m.incomingRules)
m.incomingDenyRules = m.incomingDenyRules[:0]
m.incomingAcceptRules = m.incomingAcceptRules[:0]
m.incomingDenyIndex.reset()
m.incomingAcceptIndex.reset()
clear(m.peerRulesMap)
clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
@@ -820,11 +905,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
case layers.LayerTypeIPv4:
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
return src.Unmap(), dst.Unmap()
case layers.LayerTypeIPv6:
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
return src.Unmap(), dst.Unmap()
default:
return netip.Addr{}, netip.Addr{}
}
@@ -1404,20 +1489,12 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
return nil, false
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
return nil, true
}
@@ -1438,39 +1515,6 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false
}
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue
}
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, rule.drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
continue
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, rule.drop, true
}
}
return nil, false, false
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()

View File

@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -114,15 +114,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(
_, err := m.AddFilterRule(
nil,
ip,
pfx(ip), fw.Network{},
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
"",
)
fw.ActionAccept)
require.NoError(b, err)
}
},
@@ -133,15 +131,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(
_, err := m.AddFilterRule(
nil,
net.ParseIP("0.0.0.0"),
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
fw.ActionDrop)
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -170,7 +166,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -210,7 +206,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -253,7 +249,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -411,7 +407,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -538,7 +534,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -546,7 +542,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -621,7 +617,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -629,7 +625,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -732,14 +728,14 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -812,13 +808,13 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -931,7 +927,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, r := range rules {
dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}
@@ -1016,7 +1012,7 @@ func BenchmarkMSSClamping(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -1081,7 +1077,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -1136,7 +1132,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -32,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NotNil(t, manager)
@@ -496,40 +496,32 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule for the same IP to test drop rule precedence
rules, err := manager.AddPeerFiltering(
rules, err := manager.AddFilterRule(
nil,
net.ParseIP(tc.ruleIP),
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
"",
)
fw.ActionAccept)
require.NoError(t, err)
require.NotEmpty(t, rules)
require.NotNil(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
require.NoError(t, manager.DeleteFilterRule(rules))
})
}
rules, err := manager.AddPeerFiltering(
rules, err := manager.AddFilterRule(
nil,
net.ParseIP(tc.ruleIP),
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
tc.ruleProto,
tc.ruleSrcPort,
tc.ruleDstPort,
tc.ruleAction,
"",
)
tc.ruleAction)
require.NoError(t, err)
require.NotEmpty(t, rules)
require.NotNil(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
require.NoError(t, manager.DeleteFilterRule(rules))
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
@@ -557,7 +549,7 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -672,22 +664,18 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
require.NoError(t, err)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
require.NoError(t, manager.DeleteFilterRule(rules))
})
}
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
require.NoError(t, err)
require.NotEmpty(t, rules)
require.NotNil(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
require.NoError(t, manager.DeleteFilterRule(rules))
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
@@ -800,7 +788,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager)
@@ -1405,7 +1393,7 @@ func TestRouteACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
@@ -1415,13 +1403,13 @@ func TestRouteACLFiltering(t *testing.T) {
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
require.NotEmpty(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule))
})
}
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
tc.rule.sources,
tc.rule.dest,
@@ -1431,10 +1419,10 @@ func TestRouteACLFiltering(t *testing.T) {
tc.rule.action,
)
require.NoError(t, err)
require.NotNil(t, rule)
require.NotEmpty(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule))
})
srcIP := netip.MustParseAddr(tc.srcIP)
@@ -1602,9 +1590,9 @@ func TestRouteACLOrder(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var rules []fw.Rule
var addedRules []fw.Rule
for _, r := range tc.rules {
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
r.sources,
r.dest,
@@ -1615,12 +1603,12 @@ func TestRouteACLOrder(t *testing.T) {
)
require.NoError(t, err)
require.NotNil(t, rule)
rules = append(rules, rule)
addedRules = append(addedRules, rule)
}
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeleteRouteRule(rule))
for _, rule := range addedRules {
require.NoError(t, manager.DeleteFilterRule(rule))
}
})
@@ -1646,7 +1634,7 @@ func TestRouteACLSet(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -1655,7 +1643,7 @@ func TestRouteACLSet(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -1689,7 +1677,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddRouteFiltering(
_, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
@@ -1700,7 +1688,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
)
require.NoError(t, err)
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},

View File

@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -147,7 +147,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
err = manager.DeleteFilterRule(rule1)
require.NoError(t, err)
}
@@ -156,6 +156,59 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedDualStack verifies that when the interface has an
// IPv6 overlay address, blockInvalidRouted installs a block rule for both
// the v4 and v6 WG prefixes and that routed traffic to the v6 prefix is
// denied. The v4-only soft-skip path is covered by
// TestBlockInvalidRoutedIdempotent, whose mock leaves IPv6Net invalid.
func TestBlockInvalidRoutedDualStack(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
wgNet6 := netip.MustParsePrefix("fd00:1234::1/64")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
IPv6: wgNet6.Addr(),
IPv6Net: wgNet6,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
rules, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.Len(t, rules, 2, "dual-stack interface must produce a v4 and a v6 block rule")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount, "should have one block rule per family")
// v6 routed traffic to the WG prefix must be denied.
srcIP := netip.MustParseAddr("2001:db8::1")
dstIP := netip.MustParseAddr("fd00:1234::50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "block rule should deny routed traffic to the v6 WG prefix")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
@@ -182,7 +235,7 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -245,7 +298,7 @@ func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -274,7 +327,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -304,7 +357,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -315,7 +368,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -329,7 +382,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
err = manager.DeleteFilterRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
@@ -364,7 +417,7 @@ func setupTestManager(t *testing.T) *Manager {
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())

View File

@@ -78,7 +78,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -89,7 +89,7 @@ func TestManagerCreate(t *testing.T) {
}
}
func TestManagerAddPeerFiltering(t *testing.T) {
func TestManagerAddFilterRule(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error {
@@ -98,7 +98,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -109,7 +109,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -131,7 +131,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -142,63 +142,33 @@ func TestManagerDeleteRule(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
// Check rules exist in appropriate maps
for _, r := range rule2 {
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule exists in deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if !found {
t.Errorf("rule2 is not in the expected rules map")
peerRule, ok := rule2.(*PeerRule)
require.True(t, ok, "rule should be a PeerRule")
inMap := func() bool {
if peerRule.action == fw.ActionDrop {
return findRuleByID(m.incomingDenyRules, ip, rule2.ID())
}
return findRuleByID(m.incomingAcceptRules, ip, rule2.ID())
}
for _, r := range rule2 {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
require.True(t, inMap(), "rule2 should be in the expected rules list")
// Check rules are removed from appropriate maps
for _, r := range rule2 {
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule is removed from deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if found {
t.Errorf("rule2 should be removed from the rules map")
}
}
require.NoError(t, m.DeleteFilterRule(rule2), "failed to delete rule")
require.False(t, inMap(), "rule2 should be removed from the rules list")
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
}, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -222,7 +192,7 @@ func TestSetUDPPacketHook(t *testing.T) {
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
}, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -250,7 +220,7 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -260,36 +230,34 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
err = m.DeleteFilterRule(rule1)
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
err = m.DeleteFilterRule(rule2)
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
require.False(t, exists, "Deny rules should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
@@ -299,7 +267,7 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -311,27 +279,21 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
require.NoError(t, m.DeleteFilterRule(rules))
require.NoError(t, m.DeleteFilterRule(allowRules))
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
@@ -345,7 +307,7 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -354,41 +316,39 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
err = m.DeleteFilterRule(allowRule)
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
err = m.DeleteFilterRule(denyRule)
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
allowExists := countRulesForAddr(m.incomingAcceptRules, addr) > 0
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
@@ -400,7 +360,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -411,7 +371,7 @@ func TestManagerReset(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -423,7 +383,7 @@ func TestManagerReset(t *testing.T) {
return
}
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
if len(m.incomingAcceptRules) != 0 || len(m.incomingDenyRules) != 0 {
t.Errorf("rules are not empty")
}
}
@@ -439,7 +399,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -449,7 +409,7 @@ func TestNotMatchByIP(t *testing.T) {
proto := fw.ProtocolUDP
action := fw.ActionAccept
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -502,7 +462,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(iface, nil, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@@ -521,7 +481,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
}, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close()
@@ -606,7 +566,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -621,7 +581,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
}
@@ -633,7 +593,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
}, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker
@@ -845,7 +805,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -858,7 +818,7 @@ func TestUpdateSetMerge(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -931,7 +891,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -939,7 +899,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -1051,7 +1011,7 @@ func TestMSSClamping(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger, 1280)
manager, err := Create(ifaceMock, nil, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -1243,7 +1203,7 @@ func TestShouldForward(t *testing.T) {
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
}
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -1358,7 +1318,7 @@ func TestShouldForward(t *testing.T) {
// Re-create manager to pick up the new address with IPv6
require.NoError(t, manager.Close(nil))
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
manager, err = Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
v6Cases := []struct {

View File

@@ -66,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -126,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -198,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -310,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -483,7 +483,7 @@ func BenchmarkPortDNAT(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -15,7 +15,7 @@ import (
func TestPortDNATBasic(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -51,7 +51,7 @@ func TestPortDNATBasic(t *testing.T) {
func TestPortDNATMultipleRules(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -17,7 +17,7 @@ import (
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -106,7 +106,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -154,7 +154,7 @@ func TestDNATMappingManagement(t *testing.T) {
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -204,7 +204,7 @@ func TestInboundPortDNAT(t *testing.T) {
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
}, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -0,0 +1,327 @@
//go:build uspbench
package uspfilter
import (
"fmt"
"io"
"math/rand"
"net"
"net/netip"
"runtime"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
// rules, each with K source peers in its set.
//
// With the reverse-source index, miss cost is independent of M and
// hit cost grows only with the number of rules touching a single
// srcIP, not with total rule count.
func BenchmarkPeerACLMatch(b *testing.B) {
shapes := []struct{ M, K int }{
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
}
families := []struct {
name string
v6 bool
}{{"v4", false}, {"v6", true}}
for _, fam := range families {
for _, s := range shapes {
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, true, fam.v6)
})
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, false, fam.v6)
})
}
}
}
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
// Miss packets are dropped, so they always traverse the full peer
// ACL matcher (every bucket) without short-circuiting and without
// touching conntrack. Disable conntrack for the miss case so it
// measures the matcher, not established-state lookups. The hit case
// keeps conntrack on: an accepted packet reaches trackInbound, which
// needs the trackers conntrack creates.
if !hit {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
}
bits := 32
genPkt := generatePacket
addrs := uniqueAddrs
if v6 {
bits = 128
genPkt = generatePacket6
addrs = uniqueAddrs6
}
// dstIP must be a local IP so filterInbound takes the local-traffic
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
// address the manager doesn't own would be treated as routed and
// short-circuit before the peer matcher.
dstIP := addrs(1, 2)[0]
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
if v6 {
// The local-IP manager needs a valid v4 address too; expose the v6
// dst as the interface's IPv6 so IsLocalIP recognizes it.
mockAddr = wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: dstIP,
IPv6Net: netip.PrefixFrom(dstIP, bits),
}
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { return mockAddr },
}, nil, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
// Generate M policies × K source peers, all distinct.
all := addrs(m*k, 1)
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, k)
for j, a := range all[i*k : (i+1)*k] {
sources[j] = netip.PrefixFrom(a, bits)
}
_, err := manager.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
// Hit: cycle through real sources, picking the matching policy's port.
// Miss: a source from a disjoint range, port 80 (matches no policy).
var pktFn func(i int) []byte
if hit {
pktFn = func(i int) []byte {
policy := i % m
src := all[policy*k+(i%k)]
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
}
} else {
miss := addrs(4096, 99)
pktFn = func(i int) []byte {
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
}
}
// Pre-build a pool to avoid allocations dominating the measurement.
pool := make([][]byte, 1024)
for i := range pool {
pool[i] = pktFn(i)
}
// Confirm the matcher is actually exercised: a hit packet must be
// allowed and a miss packet dropped. Without this the benchmark
// could silently time the routed early-return instead.
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
"benchmark must reach the peer ACL matcher")
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(pool[i%len(pool)], 0)
}
}
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
// the source-keyed index across realistic deployment shapes. Two
// dimensions matter: (M, K), the number of policies × peers-per-policy,
// and overlap, the fraction of peers shared between policies.
//
// The output uses ReportMetric("bytes/rule") so the cost can be
// compared across shapes directly. Total bytes = bytes/rule * M.
func BenchmarkPeerACLIndexMemory(b *testing.B) {
cases := []struct {
name string
M, K int
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
}{
{"M=10/K=100/disjoint", 10, 100, 0},
{"M=100/K=100/disjoint", 100, 100, 0},
{"M=100/K=1000/disjoint", 100, 1000, 0},
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
}
for _, c := range cases {
b.Run(c.name, func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mgr, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
runtime.GC()
var ms runtime.MemStats
runtime.ReadMemStats(&ms)
before := ms.HeapAlloc
// Drop the manager's external roots so we can isolate
// the index cost. We hold the manager itself live; the
// index is what we measure on the second pass.
mgr.incomingAcceptIndex.reset()
mgr.incomingDenyIndex.reset()
mgr.incomingAcceptRules = mgr.incomingAcceptRules[:0]
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
runtime.GC()
runtime.ReadMemStats(&ms)
after := ms.HeapAlloc
delta := int64(before) - int64(after)
if delta < 0 {
delta = 0
}
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
b.ReportMetric(float64(delta), "bytes/total")
require.NoError(b, mgr.Close(nil))
}
})
}
}
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
b.Helper()
pool := uniqueAddrs(k+m*k, 1) // big enough universe
sharedLen := int(float64(k) * overlapFrac)
if sharedLen > k {
sharedLen = k
}
shared := pool[:sharedLen]
uniquePool := pool[sharedLen:]
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, 0, k)
for _, a := range shared {
sources = append(sources, netip.PrefixFrom(a, 32))
}
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
for _, a := range unique {
sources = append(sources, netip.PrefixFrom(a, 32))
}
_, err := mgr.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
}
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
// policy sources / dst; seed 99 puts misses in 10/8.
func uniqueAddrs(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [4]byte
if miss {
b[0] = 10
b[1] = byte(r.Intn(256))
} else {
b[0] = 100
b[1] = byte(64 + r.Intn(63))
}
b[2] = byte(r.Intn(256))
b[3] = byte(1 + r.Intn(254))
a := netip.AddrFrom4(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
// disjoint from any source.
func uniqueAddrs6(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [16]byte
if miss {
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
} else {
b[0] = 0xfd
}
for x := 8; x < 16; x++ {
b[x] = byte(r.Intn(256))
}
a := netip.AddrFrom16(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
// generatePacket for the v4 case.
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
b.Helper()
ipv6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: protocol,
SrcIP: srcIP,
DstIP: dstIP,
}
var transportLayer gopacket.SerializableLayer
switch protocol {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
transportLayer = udp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
return buf.Bytes()
}

View File

@@ -0,0 +1,149 @@
package uspfilter
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
func newTestManager(t *testing.T) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err, "create manager")
return m
}
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
// the same peer rule twice does not create two backing rules. The acl
// manager keys its own cache, but the firewall backend must be
// idempotent on its own so a double-apply cannot leak rules, matching
// the route path and the kernel backends.
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
assert.Equal(t, first.ID(), second.ID(), "duplicate add should return the same rule id")
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
}
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
// backend's no-refcount contract: a content key installed twice is
// still one rule, and the first DeleteFilterRule removes it. The
// backend does not refcount, so balance is the caller's job (it keys
// its tracking by the returned id and deletes once per key). If this
// ever silently grew a refcount, the acl manager's delete accounting
// would diverge from the kernel.
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
require.Equal(t, first.ID(), second.ID(), "dedup to one rule")
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
require.NoError(t, m.DeleteFilterRule(first), "delete once")
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
assert.NotContains(t, m.peerRulesMap, first.ID(), "dedup map entry cleared")
}
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
// content hash, not a random UUID: identical inputs produce the same id
// across independent managers. A random id breaks caller-side dedup.
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
ip := net.ParseIP("10.0.0.5")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
m1 := newTestManager(t)
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
m2 := newTestManager(t)
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
assert.Equal(t, r1.ID(), r2.ID(), "same inputs must produce the same rule id")
}
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
// differing only by port are kept separate.
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
action := fw.ActionAccept
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
require.NoError(t, err)
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
require.NoError(t, err)
assert.NotEqual(t, r80.ID(), r443.ID(), "different ports must produce different rule ids")
assert.Len(t, m.incomingAcceptRules, 2, "distinct rules must both be stored")
}
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
// matching on source port and one matching on destination port for the
// same selector do not collide: the port lands in a different slot, so
// the content key must differ.
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
require.NoError(t, err)
assert.NotEqual(t, dPortRule.ID(), sPortRule.ID(), "source-port and dest-port matches must produce different rule ids")
}
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
// list is rejected rather than treated as "match any". "Match any" must
// be an explicit /0, so a zeroed list can never silently widen a rule to
// every source.
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
m := newTestManager(t)
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
assert.Empty(t, m.incomingAcceptRules, "no rule should be stored for empty sources")
}

View File

@@ -0,0 +1,105 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func newV6TestManager(t *testing.T, localV6 string) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IPv6: netip.MustParseAddr(localV6),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err, "create manager")
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
return m
}
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
SrcIP: net.ParseIP(src),
DstIP: net.ParseIP(dst),
}
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
return buf.Bytes()
}
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
// rules: a matching v6 source is accepted, a non-matching one is
// denied by the default. This is the end-to-end proof that the index
// is not v4-only.
func TestPeerACL_IPv6HostRule(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
src := net.ParseIP("fd00::1")
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err, "add v6 accept rule")
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
"v6 packet from the allowed /128 source must be accepted")
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
"v6 packet from an unlisted source must be denied by default")
}
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
// right index bucket: a /128 in bySource keyed by its address, and
// coarser prefixes (including ::/0) in the nonHost slice.
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
port := &fw.Port{Values: []uint16{53}}
host := netip.MustParseAddr("fd00::1")
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.nonHost, 1, "coarser v6 prefix must land in nonHost")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.nonHost, 2, "::/0 source must also land in nonHost")
}
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
// source prefix is normalized to v4 so a plain v4 packet matches it.
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
m := newTestManager(t)
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err)
v4 := netip.MustParseAddr("192.168.1.1")
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
}

View File

@@ -0,0 +1,139 @@
package uspfilter
import (
"net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// peerRuleIndex is the source-side dispatcher consulted on the packet
// hot path. It splits rules into two buckets by the shape of their
// source list:
//
// - bySource: every source is a host prefix (/32 for v4, /128 for
// v6). Keyed by the concrete source address, so a hit guarantees
// the source filter passes and the matcher goes straight to
// proto/port checks. This is the common case for peer ACLs.
// - nonHost: any source list with a prefix coarser than a host,
// including a /0 "match any". Walked linearly with a per-rule
// Contains() check. Expected small or empty for typical peer ACLs.
//
// Maintained incrementally by add/remove, never rebuilt.
type peerRuleIndex struct {
bySource map[netip.Addr][]*PeerRule
nonHost []*PeerRule
}
func (i *peerRuleIndex) add(r *PeerRule) {
if hasNonHostSource(r) {
i.nonHost = append(i.nonHost, r)
return
}
if i.bySource == nil {
i.bySource = make(map[netip.Addr][]*PeerRule)
}
for a := range r.sourceAddrs {
i.bySource[a] = append(i.bySource[a], r)
}
}
func (i *peerRuleIndex) remove(r *PeerRule) {
if hasNonHostSource(r) {
i.nonHost = slices.DeleteFunc(i.nonHost, eqRule(r))
return
}
if i.bySource == nil {
return
}
for a := range r.sourceAddrs {
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
if len(entries) == 0 {
delete(i.bySource, a)
} else {
i.bySource[a] = entries
}
}
}
func (i *peerRuleIndex) reset() {
i.bySource = nil
i.nonHost = i.nonHost[:0]
}
// match returns the first rule matching src and the decoded packet.
// Host rules are found by direct map lookup; nonHost rules need a
// per-rule source Contains() check, except match-any (/0) rules which
// apply to every source regardless of family (a v4 /0 also matches v6).
// Within either bucket the matcher runs the proto/port filter.
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range i.bySource[src] {
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
for _, rule := range i.nonHost {
if !rule.matchAny && !prefixesContain(rule.sources, src) {
continue
}
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
return nil, false, false
}
func eqRule(target *PeerRule) func(*PeerRule) bool {
return func(p *PeerRule) bool { return p == target }
}
// hasNonHostSource reports whether the rule has any source prefix
// that is not a single host address. Called only at add/remove time,
// not on the packet path.
func hasNonHostSource(r *PeerRule) bool {
for _, p := range r.sources {
if p.Bits() != p.Addr().BitLen() {
return true
}
}
return false
}
// matchProto applies the proto/port half of a rule against the
// decoded packet. Source matching is the caller's responsibility.
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
drop := rule.action == firewall.ActionDrop
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
return nil, false, false
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, drop, true
}
return nil, false, false
}
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
for _, p := range sources {
if p.Contains(src) {
return true
}
}
return false
}

View File

@@ -10,24 +10,49 @@ import (
// PeerRule to handle management of rules
type PeerRule struct {
id string
mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType
matchByIP bool
id firewall.RuleID
mgmtId []byte
// sources is the canonical list of source prefixes this rule
// matches against.
sources []netip.Prefix
// sourceAddrs is a fast-path membership set for host-prefix
// sources (/32 v4, /128 v6). Populated alongside sources;
// consulted before falling back to prefix scan.
sourceAddrs map[netip.Addr]struct{}
// matchAny is true when sources covers everything (0.0.0.0/0,
// ::/0). In that case neither sourceAddrs nor sources need to be
// consulted.
matchAny bool
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// matchesSource reports whether the given source address is covered
// by this rule's source list.
func (r *PeerRule) matchesSource(src netip.Addr) bool {
if r.matchAny {
return true
}
if _, ok := r.sourceAddrs[src]; ok {
return true
}
for _, p := range r.sources {
if p.Contains(src) {
return true
}
}
return false
}
// ID returns the rule id
func (r *PeerRule) ID() string {
func (r *PeerRule) ID() firewall.RuleID {
return r.id
}
type RouteRule struct {
id string
id firewall.RuleID
mgmtId []byte
sources []netip.Prefix
dstSet firewall.Set
@@ -39,6 +64,6 @@ type RouteRule struct {
}
// ID returns the rule id
func (r *RouteRule) ID() string {
func (r *RouteRule) ID() firewall.RuleID {
return r.id
}

View File

@@ -0,0 +1,50 @@
package uspfilter
import (
"net"
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// countRulesForAddr reports how many rules in the given slice match
// the supplied source address.
func countRulesForAddr(rules peerRules, src netip.Addr) int {
n := 0
for _, r := range rules {
if r.matchesSource(src) {
n++
}
}
return n
}
// findRuleByID returns true if the rules slice contains a rule with
// the given id whose source set covers src.
func findRuleByID(rules peerRules, src netip.Addr, id firewall.RuleID) bool {
for _, r := range rules {
if r.id == id && r.matchesSource(src) {
return true
}
}
return false
}
// pfx converts a single net.IP into the []netip.Prefix form
// AddFilterRule expects. A nil or unspecified address becomes a /0
// ("match any") prefix in the matching family; any other address
// becomes its /32 (or /128) host prefix.
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, _ := netip.AddrFromSlice(ip)
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -285,6 +285,14 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace.SourceIP = srcIP
trace.DestinationIP = dstIP
// A fragment or otherwise truncated packet has no transport layer.
// The inbound datapath drops these via isValidPacket; the tracer must
// guard explicitly since every downstream stage reads d.decoded[1].
if len(d.decoded) < 2 {
trace.AddResult(StageReceived, "Packet has no transport layer (fragment or unsupported protocol)", false)
return trace
}
// Determine protocol and ports
switch d.decoded[1] {
case layers.LayerTypeTCP:

View File

@@ -45,7 +45,7 @@ func TestTracePacket(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
m, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
if !statefulMode {
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {

View File

@@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart entries from every view a previous installer may have used.
; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -0,0 +1,190 @@
package acl
import (
"net/netip"
"sync"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
// invariant: the backends classify a rule as a peer rule purely by the
// absence of a destination (neither prefix nor set). A default route
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
// a route, not collapse into the peer path.
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
}
// A zero-value Network is the only peer-rule shape.
var empty fwmgr.Network
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
assert.False(t, empty.IsSet(), "zero Network must not be a set")
}
// TestDetermineDestinationAlwaysRoute verifies determineDestination
// never yields an empty Network for a valid route rule: every branch
// (static prefix, default route, dynamic with/without domains, with and
// without a local resolver) produces a destination that classifies as a
// route. If this regresses, a route rule would be dispatched down the
// peer path, which matches on source only.
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
cases := []struct {
name string
rule *mgmProto.RouteFirewallRule
resolver bool
sources []netip.Prefix
}{
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
require.NoError(t, err)
assert.True(t, dest.IsPrefix() || dest.IsSet(),
"destination must classify as a route, got empty Network")
})
}
}
// countingFirewall wraps a real firewall.Manager and counts filter-rule
// add/delete calls so a test can assert how many backing rules the acl
// manager actually creates and tears down.
type countingFirewall struct {
fwmgr.Manager
mu sync.Mutex
addCalls int
dels int
ruleIDs map[fwmgr.RuleID]struct{}
}
// distinctRules returns the number of distinct backing rules the
// backend produced. Because the backend dedups identical content,
// repeated AddFilterRule calls for the same rule resolve to one id.
func (f *countingFirewall) distinctRules() int {
f.mu.Lock()
defer f.mu.Unlock()
return len(f.ruleIDs)
}
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
rule, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if err == nil {
f.mu.Lock()
f.addCalls++
if f.ruleIDs == nil {
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
}
if rule != nil {
f.ruleIDs[rule.ID()] = struct{}{}
}
f.mu.Unlock()
}
return rule, err
}
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
err := f.Manager.DeleteFilterRule(r)
if err == nil {
f.mu.Lock()
f.dels++
delete(f.ruleIDs, r.ID())
f.mu.Unlock()
}
return err
}
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
t.Helper()
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
fw := &countingFirewall{Manager: realFW}
cleanup := func() {
require.NoError(t, realFW.Close(nil))
ctrl.Finish()
}
return NewDefaultManager(fw), fw, cleanup
}
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
// the backends rely on: two policies that authorize an identical flow
// (same selector and sources) collapse to a single backing firewall
// rule, and that rule survives until BOTH policies are gone. This is
// why the backend can dedup on add without refcounting on delete: the
// acl manager's pair key matches the backend's content key, so add and
// delete stay balanced per content key across full-state reapplies.
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
acl, fw, cleanup := newCountingACL(t)
defer cleanup()
ruleA := &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
ruleB := &mgmProto.FirewallRule{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
// Both policies present: identical content collapses to one rule.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
// Drop policy A only: the shared rule is still authorized by B, so
// nothing is deleted.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Drop policy B too: now the content key has no authorizer and the
// single backing rule is removed exactly once.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
assert.Equal(t, 0, len(acl.peerRulesPairs))
}

View File

@@ -0,0 +1,318 @@
package acl
import (
"errors"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
// with identical selectors but different PolicyIDs do NOT get merged
// into one group, so each policy's sources merge under its own
// attribution id. (Identical-content groups may still dedup to one
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
}
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
// belonging to the same policy don't merge.
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "2001:db8::1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules of different families must produce separate groups")
var sawV4, sawV6 bool
for _, g := range groups {
require.Len(t, g.sources, 1)
if g.sources[0].Addr().Is4() {
sawV4 = true
}
if g.sources[0].Addr().Is6() {
sawV6 = true
}
}
assert.True(t, sawV4 && sawV6)
}
// TestGroupPeerRulesSplitsMixedFamilySingleRule verifies that a single
// FirewallRule carrying both v4 and v6 source prefixes is split into one
// group per family. Each backend keys a rule to a single family, so a
// group whose sources span families would mismatch the other family's
// sources. mgmt normally emits one rule per family; this guards against
// a mixed-family rule slipping through.
func TestGroupPeerRulesSplitsMixedFamilySingleRule(t *testing.T) {
srcs := [][]byte{
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::1")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::2")),
}
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
SourcePrefixes: srcs,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "mixed-family sources in one rule must split into two groups")
for _, g := range groups {
require.Len(t, g.sources, 2)
v6 := prefixIsV6(g.sources[0])
for _, s := range g.sources {
assert.Equal(t, v6, prefixIsV6(s), "every source in a group must share one family")
}
}
}
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
// every distinguishing field (policy, family, direction, action,
// proto, port) collapse into a single multi-source group.
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
mk := func(peerIP string) *mgmProto.FirewallRule {
return &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: peerIP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
}
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesPortSeparates verifies that PortInfo is part of the
// selector key: rules differing only in port must not merge, and a
// single port must not merge with a range. A regression dropping the
// port from the key would collapse rules for different ports into one.
func TestGroupPeerRulesPortSeparates(t *testing.T) {
mkPort := func(peerIP string, port uint32) *mgmProto.FirewallRule {
return &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: peerIP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Port{Port: port}},
}
}
groups, denyErr, err := groupPeerRules([]*mgmProto.FirewallRule{
mkPort("10.0.0.1", 80), mkPort("10.0.0.2", 80), mkPort("10.0.0.3", 443),
})
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules on different ports must not merge")
rangeRule := &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Range_{Range: &mgmProto.PortInfo_Range{Start: 80, End: 90}}},
}
groups, denyErr, err = groupPeerRules([]*mgmProto.FirewallRule{mkPort("10.0.0.1", 80), rangeRule})
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "a single port and a range must not merge")
}
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
// new sourcePrefixes wire field is consumed and produces a
// multi-source group in one shot (no client-side merging needed).
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
srcs := [][]byte{
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
}
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
SourcePrefixes: srcs,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
// and drop rules with the same selector don't merge.
func TestGroupPeerRulesActionSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2)
}
// failingDeleteFirewall wraps a real firewall.Manager and forces the
// next N DeleteFilterRule calls to fail. Used to verify that the acl
// manager retains rules whose deletion was rejected by the backend,
// so they get retried on the next ApplyFiltering pass instead of
// becoming orphans.
type failingDeleteFirewall struct {
fwmgr.Manager
failCount int
}
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
if f.failCount > 0 {
f.failCount--
return errors.New("simulated delete failure")
}
return f.Manager.DeleteFilterRule(r)
}
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
// transient DeleteFilterRule error doesn't make the acl manager
// forget about a rule. The rule must remain in peerRulesPairs so the
// next ApplyFiltering pass attempts the delete again.
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() { require.NoError(t, realFW.Close(nil)) }()
fw := &failingDeleteFirewall{Manager: realFW}
acl := NewDefaultManager(fw)
// First pass: install a rule.
netmap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(netmap1, false)
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
// Second pass: remove the rule from the map. The backend will
// fail the delete; the acl manager must retain the rule.
fw.failCount = 1
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 1, len(acl.peerRulesPairs),
"rule must be retained when DeleteFilterRule fails so it gets retried")
// Third pass: same map, backend no longer fails. The rule
// should now succeed in being removed.
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
}

View File

@@ -5,18 +5,18 @@ import (
"encoding/hex"
"fmt"
"net/netip"
"slices"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
type RuleID string
// RuleID aliases manager.RuleID so existing nbid.RuleID references
// keep working while the canonical type lives in the firewall package.
type RuleID = manager.RuleID
func (r RuleID) ID() string {
return string(r)
}
func GenerateRouteRuleKey(
// GenerateRuleID returns a deterministic content hash identifying a filter rule.
func GenerateRuleID(
sources []netip.Prefix,
destination manager.Network,
proto manager.Protocol,
@@ -24,6 +24,7 @@ func GenerateRouteRuleKey(
dPort *manager.Port,
action manager.Action,
) RuleID {
sources = slices.Clone(sources)
manager.SortPrefixes(sources)
h := sha256.New()

View File

@@ -1,8 +1,6 @@
package acl
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"net/netip"
@@ -23,6 +21,10 @@ import (
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// ErrNoRuleReturned is returned when the firewall backend reports success
// from AddFilterRule but yields no rule to track.
var ErrNoRuleReturned = errors.New("backend returned no rule")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
@@ -31,17 +33,46 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
routeRules map[id.RuleID]firewall.Rule
mutex sync.Mutex
}
// peerRuleGroup collapses a set of single-source FirewallRules sharing
// the same selector into one multi-source rule to push to the backend.
type peerRuleGroup struct {
direction mgmProto.RuleDirection
action mgmProto.RuleAction
protocol mgmProto.RuleProtocol
port *mgmProto.PortInfo
// legacyPort is used only when PortInfo is empty (old management).
legacyPort string
policyID []byte
sources []netip.Prefix
}
// peerRuleKey is the comparable selector that decides which single-source
// rules merge into one group. Rules with an equal key collapse into one
// multi-source backend rule. PortInfo is flattened into its scalar fields
// so the key compares by value; policyID keeps policies separate so two
// policies authorizing different peers don't merge under one attribution.
type peerRuleKey struct {
v6 bool
policyID string
direction mgmProto.RuleDirection
action mgmProto.RuleAction
protocol mgmProto.RuleProtocol
legacyPort string
port uint16
rangeStart uint16
rangeEnd uint16
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
firewall: fm,
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
routeRules: make(map[id.RuleID]struct{}),
routeRules: make(map[id.RuleID]firewall.Rule),
}
}
@@ -68,10 +99,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total)
}()
d.applyPeerACLs(networkMap)
if err := d.applyPeerACLs(networkMap); err != nil {
log.Errorf("apply peer ACLs: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
log.Errorf("apply route ACLs: %v", err)
}
if err := d.firewall.Flush(); err != nil {
@@ -79,7 +112,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
}
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
rules := networkMap.FirewallRules
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
@@ -102,59 +135,158 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
)
}
newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
// Group incoming single-source rules from management by their
// (direction, action, proto, port) selector and merge sources.
// One call to the firewall backend per merged rule.
// A deny we cannot decode would leave its traffic unblocked, so skip
// the whole pass and keep existing rules until the next sync.
groups, denyErr, err := groupPeerRules(rules)
if denyErr != nil {
return fmt.Errorf("decode deny rule sources: %w", denyErr)
}
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
// the missing deny. Currently we accumulate errors and continue.
newRulePairs := make(map[id.RuleID][]firewall.Rule)
var merr *multierror.Error
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
if err != nil {
merr = multierror.Append(merr, err)
}
// Apply denies first. A deny that fails to install is a security
// failure (fail-open), so if any deny errors we roll back the
// denies we already installed in this pass and bail out without
// installing any accept. Pre-existing rules stay untouched until
// the next successful pass clears them.
denies, accepts := splitDenyAccept(groups)
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
return fmt.Errorf("install deny rules: %w", err)
}
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
merr = multierror.Append(merr, err)
}
// Tear down rules that disappeared from the networkmap. Any rule
// the backend refuses to delete stays in our tracking so it gets
// retried on the next ApplyFiltering. Otherwise a transient
// delete failure would leak the rule in the firewall until the
// process exits.
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; ok {
continue
}
if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
}
if merr != nil {
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
}
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule: %v", err)
continue
}
var remaining []firewall.Rule
for _, rule := range rules {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
remaining = append(remaining, rule)
}
delete(d.peerRulesPairs, pairID)
}
if len(remaining) > 0 {
newRulePairs[pairID] = remaining
}
}
d.peerRulesPairs = newRulePairs
return nberrors.FormatErrorOrNil(merr)
}
// installPeerGroups applies each group and records the resulting rule
// pairs in newRulePairs. With atomic set (deny rules), a single failure
// rolls back every rule installed in this call and returns, leaving the
// firewall exactly as before: denies are fail-closed and must be applied
// all-or-nothing. With atomic unset (accept rules), failures are
// accumulated and the remaining groups still install, so one malformed
// allow cannot drop every other legitimate allow in the pass.
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
var freshlyInstalled []id.RuleID
var merr *multierror.Error
for _, g := range groups {
pairID, rulePair, err := d.applyPeerGroup(g)
if err != nil {
if atomic {
d.rollbackInstalled(freshlyInstalled)
return fmt.Errorf("apply firewall rule: %w", err)
}
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
continue
}
if len(rulePair) == 0 {
continue
}
if _, existed := d.peerRulesPairs[pairID]; !existed {
freshlyInstalled = append(freshlyInstalled, pairID)
}
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
var merr *multierror.Error
for _, pairID := range pairIDs {
for _, rule := range d.peerRulesPairs[pairID] {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rule %s: %w", pairID, err))
}
}
delete(d.peerRulesPairs, pairID)
}
if err := nberrors.FormatErrorOrNil(merr); err != nil {
log.Errorf("rollback peer rules: %v", err)
}
}
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
protocol, err := ConvertToFirewallProtocol(g.protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
action, err := convertFirewallAction(g.action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
port, err := resolveGroupPort(g)
if err != nil {
return "", nil, err
}
var fwRule firewall.Rule
switch g.direction {
case mgmProto.RuleDirection_IN:
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
if shouldSkipInvertedRule(protocol, port) {
return "", nil, nil
}
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, fmt.Errorf("add firewall rule: %w", err)
}
if fwRule == nil {
return "", nil, nil
}
// Derive the pair id from the backend rule, like the route path:
// the backend dedups identical content, so two policies authorizing
// the same flow resolve to the same id and a single backing rule.
return fwRule.ID(), []firewall.Rule{fwRule}, nil
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
newRouteRules := make(map[id.RuleID]firewall.Rule, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
// Apply new rules - firewall manager will return the existing rule if already present
for _, rule := range rules {
id, err := d.applyRouteACL(rule, dynamicResolver)
addedRule, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
@@ -163,16 +295,18 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
}
continue
}
newRouteRules[id] = struct{}{}
newRouteRules[addedRule.ID()] = addedRule
}
// Clean up old firewall rules
for id := range d.routeRules {
if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
}
// implicitly deleted from the map
// Tear down old route rules; retain ones the backend refused so a
// transient failure doesn't leave orphaned rules in the firewall.
for ruleID, rule := range d.routeRules {
if _, exists := newRouteRules[ruleID]; exists {
continue
}
if err := d.firewall.DeleteFilterRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
newRouteRules[ruleID] = rule
}
}
@@ -180,102 +314,196 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (firewall.Rule, error) {
if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty
return nil, ErrSourceRangesEmpty
}
var sources []netip.Prefix
for _, sourceRange := range rule.SourceRanges {
source, err := netip.ParsePrefix(sourceRange)
if err != nil {
return "", fmt.Errorf("parse source range: %w", err)
return nil, fmt.Errorf("parse source range: %w", err)
}
sources = append(sources, source)
sources = append(sources, firewall.UnmapPrefix(source))
}
destination, err := determineDestination(rule, dynamicResolver, sources)
if err != nil {
return "", fmt.Errorf("determine destination: %w", err)
return nil, fmt.Errorf("determine destination: %w", err)
}
protocol, err := convertToFirewallProtocol(rule.Protocol)
protocol, err := ConvertToFirewallProtocol(rule.Protocol)
if err != nil {
return "", fmt.Errorf("invalid protocol: %w", err)
return nil, fmt.Errorf("invalid protocol: %w", err)
}
action, err := convertFirewallAction(rule.Action)
if err != nil {
return "", fmt.Errorf("invalid action: %w", err)
return nil, fmt.Errorf("invalid action: %w", err)
}
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
addedRule, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
if err != nil {
return "", fmt.Errorf("add route rule: %w", err)
return nil, fmt.Errorf("add route rule: %w", err)
}
if addedRule == nil {
return nil, fmt.Errorf("add route rule: %w", ErrNoRuleReturned)
}
return id.RuleID(addedRule.ID()), nil
return addedRule, nil
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (id.RuleID, []firewall.Rule, error) {
ip, err := extractRuleIP(r)
if err != nil {
return "", nil, err
// splitDenyAccept partitions groups by action so denies can be
// applied before accepts. Order within each bucket is preserved.
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
for _, g := range groups {
if g.action == mgmProto.RuleAction_DROP {
denies = append(denies, g)
} else {
accepts = append(accepts, g)
}
}
return denies, accepts
}
// groupPeerRules merges single-source rules sharing a selector into
// multi-source groups. It splits source-decode failures by action:
// denyErr is non-nil when a deny rule could not be decoded, which is a
// fail-open risk the caller must treat as fatal for the pass; err
// carries the tolerable accept-rule failures the caller can log and
// continue past.
func groupPeerRules(rules []*mgmProto.FirewallRule) (groups []*peerRuleGroup, denyErr error, err error) {
var denyMerr, acceptMerr *multierror.Error
byKey := make(map[peerRuleKey]*peerRuleGroup)
order := make([]peerRuleKey, 0)
for _, r := range rules {
srcs, decErr := extractRuleSources(r)
if decErr != nil {
if r.Action == mgmProto.RuleAction_DROP {
denyMerr = multierror.Append(denyMerr, decErr)
} else {
acceptMerr = multierror.Append(acceptMerr, decErr)
}
continue
}
// A single FirewallRule normally carries one address family, but
// split by family defensively: each backend keys a rule to one
// family and would mismatch sources of the other, so a group's
// sources must never span families.
v4, v6 := splitPrefixesByFamily(srcs)
for _, sub := range []struct {
isV6 bool
sources []netip.Prefix
}{{false, v4}, {true, v6}} {
if len(sub.sources) == 0 {
continue
}
key := ruleGroupKey(r, sub.isV6)
g, ok := byKey[key]
if !ok {
g = &peerRuleGroup{
direction: r.Direction,
action: r.Action,
protocol: r.Protocol,
port: r.PortInfo,
legacyPort: r.Port,
policyID: r.PolicyID,
}
byKey[key] = g
order = append(order, key)
}
g.sources = append(g.sources, sub.sources...)
}
}
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
out := make([]*peerRuleGroup, 0, len(order))
for _, k := range order {
out = append(out, byKey[k])
}
return out, nberrors.FormatErrorOrNil(denyMerr), nberrors.FormatErrorOrNil(acceptMerr)
}
func prefixIsV6(p netip.Prefix) bool {
return p.Addr().Is6() && !p.Addr().Is4In6()
}
// splitPrefixesByFamily partitions prefixes into IPv4 and IPv6 groups.
func splitPrefixesByFamily(prefixes []netip.Prefix) (v4, v6 []netip.Prefix) {
for _, p := range prefixes {
if prefixIsV6(p) {
v6 = append(v6, p)
} else {
v4 = append(v4, p)
}
}
return v4, v6
}
// ruleGroupKey builds the selector key for a rule. v6 must reflect the
// rule's source family: mgmt emits one rule per family and mixing them
// would break ICMP-variant selection in uspfilter.
func ruleGroupKey(r *mgmProto.FirewallRule, v6 bool) peerRuleKey {
k := peerRuleKey{
v6: v6,
policyID: string(r.PolicyID),
direction: r.Direction,
action: r.Action,
protocol: r.Protocol,
legacyPort: r.Port,
}
if pi := r.PortInfo; pi != nil {
k.port = uint16(pi.GetPort())
if rng := pi.GetRange(); rng != nil {
k.rangeStart = uint16(rng.GetStart())
k.rangeEnd = uint16(rng.GetEnd())
}
}
return k
}
// extractRuleSources returns all source prefixes the rule applies to.
// New management populates sourcePrefixes; older management sets PeerIP.
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
if len(r.SourcePrefixes) > 0 {
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
for _, raw := range r.SourcePrefixes {
addr, err := netiputil.DecodeAddr(raw)
if err != nil {
return nil, fmt.Errorf("decode source prefix: %w", err)
}
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
}
return out, nil
}
action, err := convertFirewallAction(r.Action)
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
addr = addr.Unmap()
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
}
var port *firewall.Port
if !portInfoEmpty(r.PortInfo) {
port = convertPortInfo(r.PortInfo)
} else if r.Port != "" {
// old version of management, single port
value, err := strconv.Atoi(r.Port)
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
if !portInfoEmpty(g.port) {
return convertPortInfo(g.port), nil
}
if g.legacyPort != "" {
value, err := strconv.Atoi(g.legacyPort)
if err != nil {
return "", nil, fmt.Errorf("invalid port: %w", err)
return nil, fmt.Errorf("invalid port: %w", err)
}
port = &firewall.Port{
return &firewall.Port{
Values: []uint16{uint16(value)},
}
}, nil
}
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, err
}
return ruleID, rules, nil
// nolint:nilnil // a nil port legitimately means "no port restriction"
return nil, nil
}
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
@@ -294,85 +522,9 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
}
}
func (d *DefaultManager) addInRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
func (d *DefaultManager) addOutRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
if shouldSkipInvertedRule(protocol, port) {
return nil, nil
}
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
// getPeerRuleID returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID(
ip netip.Addr,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
if port != nil {
idStr += port.String()
}
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
}
// extractRuleIP extracts the peer IP from a firewall rule.
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
// Otherwise fall back to the deprecated PeerIP string field (old management).
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
if len(r.SourcePrefixes) > 0 {
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
if err != nil {
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
}
return addr.Unmap(), nil
}
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
}
return addr.Unmap(), nil
}
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
// ConvertToFirewallProtocol maps a management rule protocol to the
// firewall protocol type.
func ConvertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewall.ProtocolTCP, nil

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
@@ -76,9 +77,9 @@ func TestDefaultManager(t *testing.T) {
})
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
existedPairs := map[fwmanager.RuleID]struct{}{}
for id := range acl.peerRulesPairs {
existedPairs[id.ID()] = struct{}{}
existedPairs[id] = struct{}{}
}
// remove first rule
@@ -105,7 +106,7 @@ func TestDefaultManager(t *testing.T) {
// check that old rule was removed
previousCount := 0
for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id.ID()]; ok {
if _, ok := existedPairs[id]; ok {
previousCount++
}
}

View File

@@ -360,7 +360,13 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
return true
}
addr := fmt.Sprintf(":%s", port)
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
// alias by default, ensure explicit ip for localhost.
host := parsedURL.Hostname()
if host == "" {
host = "127.0.0.1"
}
addr := net.JoinHostPort(host, port)
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false

View File

@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
return strings.HasSuffix(qname, "."+entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match

View File

@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: true,
shouldMatch: true,
},
{
name: "wildcard label-boundary mismatch (suffix overlap)",
handlerDomain: "*.b.test.",
queryDomain: "x.ab.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard label-boundary match",
handlerDomain: "*.b.test.",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard multi-label match",
handlerDomain: "*.b.test.",
queryDomain: "x.y.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on multi-label apex",
handlerDomain: "*.b.test.",
queryDomain: "b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard no match on unrelated suffix containment",
handlerDomain: "*.example.com.",
queryDomain: "notexample.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard accepts pattern registered without trailing dot",
handlerDomain: "*.b.test",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
}
for _, tt := range tests {
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "overlapping wildcard suffixes route to correct handler",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "app.ab.test.",
expectedCalls: 1,
expectedHandler: 1,
},
{
name: "root zone with specific domain",
handlers: []struct {

View File

@@ -26,6 +26,19 @@ type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
// client knows about and whether that peer is currently connected. The
// local resolver uses this to suppress A/AAAA answers whose RDATA points
// at a disconnected peer (typical case: a synthesized private-service
// record pointing at an embedded proxy peer that just went offline).
//
// known=false means the IP isn't in the local peerstore at all — the
// record is left alone (it points at something outside our mesh, e.g.
// a non-peer upstream).
type PeerConnectivity interface {
IsConnectedByIP(ip string) (known, connected bool)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
@@ -33,6 +46,11 @@ type Resolver struct {
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
// peerConn, when non-nil, is consulted on every A/AAAA answer to
// drop records pointing at disconnected peers. nil disables the
// filter and preserves the legacy "return whatever is registered"
// behaviour for callers that never wire a status source.
peerConn PeerConnectivity
ctx context.Context
cancel context.CancelFunc
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
}
}
// SetPeerConnectivity wires the per-IP connectivity check used to filter
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
// Safe to call multiple times; the latest value wins.
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerConn = p
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
@@ -95,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question)
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
@@ -436,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
}
}
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
// a known but disconnected peer. The synthesized private-service zones
// emit one A record per connected proxy peer in a cluster; when a peer
// goes offline, the server-side refresh removes the record from the
// next netmap, but the client may still hold the previous netmap for a
// short window. This filter is the local belt to that braces — even on
// the stale netmap, the resolver hides the offline target.
//
// Records pointing at unknown IPs (outside the local peerstore, e.g.
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
// through untouched.
//
// Escape hatch: if filtering would leave the answer empty AND at least
// one record was filtered, the original list is returned. Better to
// hand the client a record that may not respond than NXDOMAIN it
// completely when every proxy peer is offline (the upstream may still
// be reachable some other way, or the peerstore may be stale).
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
if len(records) == 0 {
return records
}
d.mu.RLock()
checker := d.peerConn
d.mu.RUnlock()
if checker == nil {
return records
}
kept := make([]dns.RR, 0, len(records))
var dropped int
for _, rr := range records {
ip := extractRecordIP(rr)
if ip == "" {
kept = append(kept, rr)
continue
}
known, connected := checker.IsConnectedByIP(ip)
if known && !connected {
dropped++
continue
}
kept = append(kept, rr)
}
if dropped == 0 {
return records
}
if len(kept) == 0 {
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
return records
}
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
return kept
}
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
// an A or AAAA record, or "" for any other record type.
func extractRecordIP(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
if r.A == nil {
return ""
}
return r.A.String()
case *dns.AAAA:
if r.AAAA == nil {
return ""
}
return r.AAAA.String()
}
return ""
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()

View File

@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return nil, nil
}
// mockPeerConnectivity returns canned (known, connected) results per IP.
// Used by the disconnected-peer filter tests below. IPs not in the map
// are reported as unknown so the filter leaves them alone.
type mockPeerConnectivity struct {
byIP map[string]struct{ known, connected bool }
}
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
v, ok := m.byIP[ip]
if !ok {
return false, false
}
return v.known, v.connected
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver.isInManagedZone(qname)
}
}
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
// connectivity-aware filtering layered on top of lookupRecords:
// when an A record's IP belongs to a known peer that's disconnected,
// the record is dropped from the answer. Records for unknown IPs pass
// through. If filtering would empty the answer entirely and at least
// one record was dropped, the original list is restored (escape hatch
// for the "all proxies offline" case).
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
zone := "svc.cluster.netbird."
connectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.10",
}
disconnectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.11",
}
unknownRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "203.0.113.5",
}
type ipState struct{ known, connected bool }
tests := []struct {
name string
records []nbdns.SimpleRecord
connByIP map[string]ipState
wantInOrder []string
}{
{
name: "drops disconnected peer, keeps connected",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: true},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.10"},
},
{
name: "unknown IPs pass through untouched",
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"203.0.113.5"},
},
{
name: "all disconnected falls back to original list",
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: false},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
},
{
name: "no checker wired returns all records",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: nil,
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
if tc.connByIP != nil {
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
for ip, st := range tc.connByIP {
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
}
resolver.SetPeerConnectivity(cm)
}
resolver.Update([]nbdns.CustomZone{{
Domain: strings.TrimSuffix(zone, "."),
Records: tc.records,
NonAuthoritative: true,
}})
var got *dns.Msg
writer := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
got = m
return nil
},
}
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
resolver.ServeDNS(writer, req)
require.NotNil(t, got, "resolver must produce a response")
require.Len(t, got.Answer, len(tc.wantInOrder),
"answer count must match expected: %v", tc.wantInOrder)
for i, want := range tc.wantInOrder {
a, ok := got.Answer[i].(*dns.A)
require.True(t, ok, "answer[%d] must be an A record", i)
assert.Equal(t, want, a.A.String(),
"answer[%d] expected %s got %s", i, want, a.A.String())
}
})
}
}

View File

@@ -301,6 +301,11 @@ func newDefaultServer(
warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
// suppress A/AAAA answers that point at disconnected peers (typical
// case: synthesised private-service records pointing at an embedded
// proxy peer that just went offline).
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
@@ -1386,3 +1391,25 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
}
return nil
}
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
// the local resolver can ask "is this IP a known peer and is it
// connected?" without taking on the peer package as a dependency.
// A nil status recorder always reports known=false so the resolver
// short-circuits to the legacy "return everything" path.
type localPeerConnectivity struct {
status *peer.Status
}
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
if l.status == nil {
return false, false
}
state, ok := l.status.PeerStateByIP(ip)
if !ok {
return false, false
}
return true, state.ConnStatus == peer.StatusConnected
}

View File

@@ -900,7 +900,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
pf, err := uspfilter.Create(wgIface, nil, false, flowLogger, iface.DefaultMTU)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err

View File

@@ -3,7 +3,6 @@ package dnsfwd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"strconv"
@@ -160,12 +159,13 @@ func (m *Manager) allowDNSFirewall() error {
return nil
}
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
dnsRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
if err != nil {
return fmt.Errorf("add udp firewall rule: %w", err)
}
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
tcpRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
if err != nil {
return fmt.Errorf("add tcp firewall rule: %w", err)
}
@@ -174,8 +174,12 @@ func (m *Manager) allowDNSFirewall() error {
return fmt.Errorf("flush: %w", err)
}
m.fwRules = dnsRules
m.tcpRules = tcpRules
if dnsRule != nil {
m.fwRules = []firewall.Rule{dnsRule}
}
if tcpRule != nil {
m.tcpRules = []firewall.Rule{tcpRule}
}
m.registerNetstackServices()
@@ -209,12 +213,12 @@ func (m *Manager) unregisterNetstackServices() {
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range m.fwRules {
if err := m.firewall.DeletePeerRule(rule); err != nil {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
for _, rule := range m.tcpRules {
if err := m.firewall.DeletePeerRule(rule); err != nil {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}

View File

@@ -640,14 +640,14 @@ func (e *Engine) initFirewall() error {
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
if _, err := e.firewall.AddPeerFiltering(
if _, err := e.firewall.AddFilterRule(
nil,
net.IP{0, 0, 0, 0},
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewallManager.Network{},
firewallManager.ProtocolUDP,
nil,
&port,
firewallManager.ActionAccept,
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
@@ -697,7 +697,7 @@ func (e *Engine) blockLanAccess() {
if network.Addr().Is6() {
source = v6
}
if _, err := e.firewall.AddRouteFiltering(
if _, err := e.firewall.AddFilterRule(
nil,
[]netip.Prefix{source},
firewallManager.Network{Prefix: network},
@@ -1967,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
// Performance bundles runtime-adjustable tunnel pool knobs.
// See Engine.SetPerformance. Nil fields are ignored.
type Performance struct {
PreallocatedBuffersPerPool *uint32
}
// SetPerformance applies the given tuning to this engine's live Device.
func (e *Engine) SetPerformance(t Performance) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
@@ -2346,7 +2369,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
var merr *multierror.Error
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
for _, rule := range rules {
proto, err := convertToFirewallProtocol(rule.GetProtocol())
proto, err := acl.ConvertToFirewallProtocol(rule.GetProtocol())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
continue

View File

@@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -66,8 +66,8 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)

View File

@@ -24,14 +24,14 @@ type RulePair struct {
type Manager struct {
dnatFirewall DNATFirewall
rules map[string]RulePair // keys is the ID of the ForwardRule
rules map[firewall.RuleID]RulePair
rulesMu sync.Mutex
}
func NewManager(dnatFirewall DNATFirewall) *Manager {
return &Manager{
dnatFirewall: dnatFirewall,
rules: make(map[string]RulePair),
rules: make(map[firewall.RuleID]RulePair),
}
}
@@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
var mErr *multierror.Error
toDelete := make(map[string]RulePair, len(h.rules))
toDelete := make(map[firewall.RuleID]RulePair, len(h.rules))
for id, r := range h.rules {
toDelete[id] = r
}
@@ -59,6 +59,10 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
continue
}
if rule == nil {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': backend returned no rule", fwdRule.String()))
continue
}
log.Infof("forward rule has been added '%s'", fwdRule)
h.rules[id] = RulePair{
ForwardRule: fwdRule,
@@ -90,7 +94,7 @@ func (h *Manager) Close() error {
}
}
h.rules = make(map[string]RulePair)
h.rules = make(map[firewall.RuleID]RulePair)
return nberrors.FormatErrorOrNil(mErr)
}

View File

@@ -14,11 +14,11 @@ var (
)
type MocFwRule struct {
id string
id firewall.RuleID
}
func (m *MocFwRule) ID() string {
return string(m.id)
func (m *MocFwRule) ID() firewall.RuleID {
return m.id
}
type MockDNATFirewall struct {

View File

@@ -10,21 +10,6 @@ import (
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewallManager.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP:
return firewallManager.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP:
return firewallManager.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL:
return firewallManager.ProtocolALL, nil
default:
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
if portInfo == nil {
return nil, errors.New("portInfo cannot be nil")

View File

@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
route, flags, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
@@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
switch msg.Type {
case unix.RTM_ADD:
if systemops.IgnoreAddedDefaultRoute(flags) {
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
continue
}
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
@@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
return nil, 0, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
r, err := systemops.MsgToRoute(msg)
if err != nil {
return nil, 0, err
}
return r, msg.Flags, nil
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey()
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey
}
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
return s.eventsChan
}
// Status holds a state of peers, signal, management connections and relays
// Status holds a state of peers, signal, management connections and relays.
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
// every private-service request) don't contend against each other.
// Pure read methods take RLock; anything that mutates state takes Lock.
type Status struct {
mux sync.Mutex
mux sync.RWMutex
peers map[string]State
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool
@@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
// GetPeer adds peer to Daemon status map
func (d *Status) GetPeer(peerPubKey string) (State, error) {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
state, ok := d.peers[peerPubKey]
if !ok {
@@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
}
func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if state.IP == ip {
@@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
return "", false
}
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
}
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}
// RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
@@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return d.localPeer.Clone()
}
@@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetLazyConnection() bool {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return d.lazyConnectionEnabled
}
func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
// if peer is connected to the management then login is not expired
if d.managementState {
@@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
if d.ingressGwMgr == nil {
return nil
}
@@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
// shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates)
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return maps.Clone(d.resolvedDomainsStates)
}
@@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus {
LazyConnectionEnabled: d.GetLazyConnection(),
}
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
fullStatus.LocalPeerState = d.localPeer
@@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
}
func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
}

View File

@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal")
}
func TestStatus_PeerStateByIP(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
state, ok := status.PeerStateByIP("100.64.0.10")
req.True(ok, "known tunnel IP should resolve to a peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
_, ok = status.PeerStateByIP("100.64.0.99")
req.False(ok, "unknown IP must report ok=false")
}
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
state, ok := status.PeerStateByIP("fd00::1")
req.True(ok, "IPv6-only match must resolve to the peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"

View File

@@ -179,8 +179,10 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv4zero
if runtime.GOOS == "linux" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux/Android.
// TODO: on android/ios, use platform APIs (ConnectivityManager.getLinkProperties /
// NWPathMonitor) when netlink-based lookup is restricted or unavailable.
dst = net.IPv4(0, 0, 0, 1)
}
_, gateway, localIP, err = router.Route(dst)
@@ -203,7 +205,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv6zero
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
// ::2
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
}

View File

@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
return hex.EncodeToString(hasher.Sum(nil))
}
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
// so tests can substitute a mock without spinning up a real UDP server.
type rpServer interface {
AddPeer(rp.PeerConfig) (rp.PeerID, error)
RemovePeer(rp.PeerID) error
Run() error
Close() error
}
type Manager struct {
ifaceName string
spk []byte
@@ -36,7 +45,7 @@ type Manager struct {
preSharedKey *[32]byte
rpPeerIDs map[string]*rp.PeerID
rpWgHandler *NetbirdHandler
server *rp.Server
server rpServer
lock sync.Mutex
port int
wgIface PresharedKeySetter
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
return &Manager{
ifaceName: wgIfaceName,
rpKeyHash: rpKeyHash,
spk: public,
ssk: secret,
preSharedKey: (*[32]byte)(preSharedKey),
rpPeerIDs: make(map[string]*rp.PeerID),
// rpWgHandler is created here (instead of only in generateConfig) so it
// is never nil between NewManager and Run(). Otherwise an early
// OnConnected call (race observed on Android, issue #4341) panics on
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
// replace it with a fresh handler on each Run() to clear stale peer
// state from previous engine sessions.
rpWgHandler: NewNetbirdHandler(),
lock: sync.Mutex{},
}, nil
}
func (m *Manager) GetPubKey() []byte {
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
// addPeer adds a new peer to the Rosenpass server
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
// Defense in depth against issue #4341 (Android crash): if Run() has not
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
// error instead of panicking on nil-receiver dereference.
if m.server == nil {
return fmt.Errorf("rosenpass server not initialized")
}
if m.rpWgHandler == nil {
return fmt.Errorf("rosenpass wg handler not initialized")
}
var err error
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
if m.preSharedKey != nil {
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
}
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
// IPv6 so its address family matches our listening socket.
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
pcfg.Endpoint.IP = v4.To16()
}
}
peerID, err := m.server.AddPeer(pcfg)
if err != nil {
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
return err
}
m.server, err = rp.NewUDPServer(conf)
server, err := rp.NewUDPServer(conf)
if err != nil {
return err
}
m.lock.Lock()
m.server = server
m.lock.Unlock()
log.Infof("starting rosenpass server on port %d", m.port)
return m.server.Run()
return server.Run()
}
// Close closes the Rosenpass server
func (m *Manager) Close() error {
if m.server != nil {
err := m.server.Close()
if err != nil {
log.Errorf("failed closing local rosenpass server")
}
m.server = nil
m.lock.Lock()
server := m.server
m.server = nil
m.lock.Unlock()
if server == nil {
return nil
}
if err := server.Close(); err != nil {
log.Errorf("failed closing local rosenpass server: %v", err)
}
return nil
}

View File

@@ -1,14 +1,412 @@
package rosenpass
import (
"errors"
"os"
"sync"
"testing"
rp "cunicu.li/go-rosenpass"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// --- test doubles -----------------------------------------------------------
type addPeerCall struct {
cfg rp.PeerConfig
}
type removePeerCall struct {
id rp.PeerID
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
if m.addErr != nil {
return rp.PeerID{}, m.addErr
}
// Increment a byte in nextID so distinct peers get distinct IDs.
m.nextID[0]++
return m.nextID, nil
}
func (m *mockServer) RemovePeer(id rp.PeerID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.removed = append(m.removed, removePeerCall{id: id})
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {
peerKey string
psk wgtypes.Key
updateOnly bool
}
type mockIface struct {
mu sync.Mutex
calls []setPSKCall
err error
}
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
return m.err
}
// newTestManager builds a Manager with deterministic spk so tie-break
// against a peer pubkey is controllable from tests. The provided spk byte
// becomes the first byte; remaining bytes are zero.
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
spk := make([]byte, 32)
spk[0] = spkFirstByte
return &Manager{
ifaceName: "wt0",
spk: spk,
ssk: make([]byte, 32),
rpKeyHash: "test-hash",
rpPeerIDs: make(map[string]*rp.PeerID),
rpWgHandler: NewNetbirdHandler(),
server: mock,
}
}
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
func validWGKey(t *testing.T, lastByte byte) string {
t.Helper()
var k wgtypes.Key
k[31] = lastByte
return k.String()
}
// --- pure helpers ----------------------------------------------------------
func TestHashRosenpassKey_Deterministic(t *testing.T) {
key := []byte("hello-rosenpass")
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
}
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
}
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
// can only set, not delete, so do it manually with restore via t.Cleanup.
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
t.Cleanup(func() {
if hadPrev {
_ = os.Setenv(defaultLogLevelVar, prev)
} else {
_ = os.Unsetenv(defaultLogLevelVar)
}
})
require.Equal(t, defaultLog.String(), getLogLevel().String())
}
func TestGetLogLevel_Cases(t *testing.T) {
cases := map[string]string{
"debug": "DEBUG",
"info": "INFO",
"warn": "WARN",
"error": "ERROR",
"unknown": "INFO", // default fallback
}
for input, wantStr := range cases {
input, wantStr := input, wantStr
t.Run(input, func(t *testing.T) {
t.Setenv(defaultLogLevelVar, input)
require.Equal(t, wantStr, getLogLevel().String())
})
}
}
func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort()
require.NoError(t, err)
require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535)
}
// --- addPeer ---------------------------------------------------------------
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // local spk lexicographically larger
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
require.Len(t, srv.addCalls, 1)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep, "initiator side must set Endpoint")
require.Equal(t, 7000, ep.Port)
require.Equal(t, "100.1.1.1", ep.IP.String())
}
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep)
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
}
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0x00, srv) // local spk smaller
remotePubKey := make([]byte, 32)
remotePubKey[0] = 0xFF
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
require.NoError(t, err)
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
}
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
srv := &mockServer{}
psk := &wgtypes.Key{0x42}
m := newTestManager(0xFF, srv)
m.preSharedKey = (*[32]byte)(psk)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
require.NoError(t, err)
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
}
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
}
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
require.Error(t, err)
}
func TestAddPeer_ServerError_Propagates(t *testing.T) {
srv := &mockServer{addErr: errors.New("boom")}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
}
// Regression guard for issue #4341 (Android crash). If Run() has not completed
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.rpWgHandler = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "wg handler not initialized")
}
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
m := newTestManager(0xFF, nil)
m.server = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "server not initialized")
}
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
// issue #4341 cannot occur in the window between NewManager and Run().
func TestNewManager_PreInitializesHandler(t *testing.T) {
psk := wgtypes.Key{}
m, err := NewManager(&psk, "wt0")
require.NoError(t, err)
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
}
func TestAddPeer_RecordsPeerID(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 5)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
require.NoError(t, err)
require.Contains(t, m.rpPeerIDs, wgKey)
}
// --- OnConnected / OnDisconnected ------------------------------------------
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
require.Empty(t, m.rpPeerIDs)
}
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
require.Len(t, srv.addCalls, 1)
require.Contains(t, m.rpPeerIDs, wgKey)
}
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnDisconnected(validWGKey(t, 99))
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
}
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.Contains(t, m.rpPeerIDs, wgKey)
m.OnDisconnected(wgKey)
require.Len(t, srv.removed, 1)
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
}
// --- IsPresharedKeyInitialized ---------------------------------------------
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
}
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 2)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.False(t, m.IsPresharedKeyInitialized(wgKey))
}
// --- NetbirdHandler.outputKey ----------------------------------------------
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x01}
wgKey := wgtypes.Key{0xAA}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
psk := rp.Key{0xBB}
h.HandshakeCompleted(pid, psk)
require.Len(t, iface.calls, 1)
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x02}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
require.Len(t, iface.calls, 2)
require.False(t, iface.calls[0].updateOnly)
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
}
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
h := NewNetbirdHandler()
// no SetInterface — iface remains nil
pid := rp.PeerID{0x03}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
// Must not panic.
h.HandshakeCompleted(pid, rp.Key{})
}
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
}
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x04}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
h.HandshakeCompleted(pid, rp.Key{0x01})
require.True(t, h.IsPeerInitialized(pid))
h.RemovePeer(pid)
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
}
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
h := NewNetbirdHandler()
pid := rp.PeerID{0x05}
wgKey := wgtypes.Key{0xEE}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
iface := &mockIface{}
h.SetInterface(iface) // set after AddPeer
h.HandshakeCompleted(pid, rp.Key{0x42})
require.Len(t, iface.calls, 1)
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}

View File

@@ -0,0 +1,42 @@
package rosenpass
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
// of peer public keys. Both peers, given the same key pair, produce the same
// output regardless of which side runs the function: the inputs are ordered
// lexicographically before concatenation.
//
// NetBird uses this value as the initial Rosenpass-side preshared key when no
// explicit account-level PSK is configured, so both peers converge on the same
// PSK before the first post-quantum handshake completes.
//
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
// from public keys and exists only to seed WireGuard until Rosenpass rotates
// in a real post-quantum PSK.
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
lk := []byte(localKey)
rk := []byte(remoteKey)
if len(lk) < 16 || len(rk) < 16 {
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
}
var keyInput []byte
if localKey > remoteKey {
keyInput = append(keyInput, lk[:16]...)
keyInput = append(keyInput, rk[:16]...)
} else {
keyInput = append(keyInput, rk[:16]...)
keyInput = append(keyInput, lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
}
return &key, nil
}

View File

@@ -0,0 +1,44 @@
package rosenpass
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
// Peer A and peer B must derive the same PSK regardless of which side
// computes it: the function orders inputs internally.
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyBA, err := DeterministicSeedKey(b, a)
require.NoError(t, err)
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
}
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
c := strings.Repeat("c", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyAC, err := DeterministicSeedKey(a, c)
require.NoError(t, err)
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
}
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
short := "short" // < 16 bytes
long := strings.Repeat("x", 32)
_, err := DeterministicSeedKey(short, long)
require.Error(t, err)
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

Some files were not shown because too many files have changed in this diff Show More