Compare commits

..

1 Commits

Author SHA1 Message Date
mlsmaycon
5eb28acb11 [management] Account-scoped ephemeral peer cleanup
Replace the per-peer linked list with a per-account map keyed by
accountID. Each entry holds only the latest disconnect timestamp we
have observed for that account and a single timer that fires the next
sweep. Sweeps query the database for the authoritative stale set,
batch the deletes through peers.Manager.DeletePeers, then drop the
account from the tracker when lastDisc + lifeTime <= now (else
re-arm at horizon + cleanupWindow).

The drop rule is the entire termination story: an account stays
tracked only while OnPeerDisconnected keeps refreshing the
timestamp. There is no internal feedback loop that can advance
lastDisc on its own, so once disconnects stop the account drops in
at most one sweep.

A timestamp beats the ref-counter alternative because the counter
drifts positive in three real situations the cleanup loop has no
signal for: peers deleted via the API while offline, peers that
reconnect within the lifetime window, and management restarts. The
timestamp design never claims to know the size of the stale set —
it only knows the latest disconnect we observed and uses that to
bound when it is safe to drop the account.

OnPeerConnected becomes a no-op. The sweep query already filters
reconnected peers at the database level (peer_status_connected =
false in the WHERE clause), so there is nothing the in-memory
tracker needs to do on reconnect. The interface method is preserved
for call-site compatibility.

LoadInitialPeers no longer runs the catch-up query synchronously.
It schedules a deferred load via time.AfterFunc at a random delay
between 8 and 10 minutes. Without the jitter, every management
replica in a fleet-wide deploy would issue the catch-up query
simultaneously. The catch-up itself is one GROUP BY against the
peers table:
```sql
  SELECT account_id, MAX(peer_status_last_seen)
  FROM peers
  WHERE ephemeral = true AND peer_status_connected = false
  GROUP BY account_id
```
For each row the tracker seeds an entry and arms a sweep at
max(now, last_seen + lifeTime) + cleanupWindow — so accounts whose
backlog is already stale get cleaned soon after the delay elapses,
and accounts that disconnected recently wait the remaining window.
OnPeerDisconnected calls that arrive during the delay window seed
the tracker live, and the catch-up query skips accounts that are
already tracked.

Stop() cancels both the deferred initial-load timer and every
per-account sweep timer, and flips a stopped flag so subsequent
OnPeerDisconnected calls are ignored. This makes restarts and test
teardown clean.

Two new store methods:
  GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
  GetEphemeralAccountsLastDisconnect(ctx)
Both are scoped, indexable queries that the existing peers table
supports without schema changes.

The pending metric is renamed from
management.ephemeral.peers.pending to
management.ephemeral.accounts.tracked to reflect the new semantics
(it now counts accounts on the cleanup list, not peers). Method
names on the metrics type are unchanged so no production call site
has to move. No new metric labels, no per-account cardinality.

The algorithm was validated against an in-memory SQLite peers
table through an 11-scenario prototype kept under proto/, including
pathological-churn and 4-hour randomized simulations. All scenarios
terminate; max observed per-account sweep rate stays bounded near
the lifeTime + cleanupWindow cadence even under sustained
disconnect churn.

Verification: go build, go vet, race-clean tests across the
ephemeral, store, and telemetry packages, plus a clean
golangci-lint pass on the touched packages.
2026-05-19 09:50:14 +02:00
278 changed files with 9913 additions and 21960 deletions

View File

@@ -1,45 +0,0 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
actions:
patterns:
- "*"
ignore:
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
- dependency-name: "git-town/action"
update-types:
- "version-update:semver-minor"
- "version-update:semver-major"
- package-ecosystem: "gomod"
directories:
- "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
aws-sdk:
patterns:
- "github.com/aws/aws-sdk-go-v2/*"
pion:
patterns:
- "github.com/pion/*"
gorm:
patterns:
- "gorm.io/*"
otel:
patterns:
- "go.opentelemetry.io/*"
testcontainers:
patterns:
- "github.com/testcontainers/testcontainers-go/*"
wireguard:
patterns:
- "golang.zx2c4.com/wireguard*"

View File

@@ -12,7 +12,6 @@
- [ ] Is a feature enhancement
- [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible)
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -2,16 +2,16 @@ name: Check License Dependencies
on:
push:
branches: [main]
branches: [ main ]
paths:
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
pull_request:
paths:
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
jobs:
check-internal-dependencies:
@@ -19,10 +19,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
@@ -59,57 +56,55 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: true
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
done <<< "$COPYLEFT_DEPS"
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"

View File

@@ -83,7 +83,7 @@ jobs:
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
id: verify
with:
pr_number: ${{ steps.extract.outputs.pr_number }}

View File

@@ -8,10 +8,11 @@ jobs:
post:
runs-on: ubuntu-latest
steps:
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
- uses: roots/discourse-topic-github-release-action@main
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags: releases
discourse-tags:
releases

View File

@@ -3,7 +3,7 @@ name: Git Town
on:
pull_request:
branches:
- "**"
- '**'
jobs:
git-town:
@@ -15,9 +15,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
- uses: actions/checkout@v4
- uses: git-town/action@v1.2.1
with:
skip-single-stacks: true

View File

@@ -16,18 +16,16 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
@@ -46,3 +44,4 @@ jobs:
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)

View File

@@ -15,31 +15,20 @@ jobs:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
envs: "GO_VERSION"
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands
# -e - to faile on first error

View File

@@ -18,11 +18,9 @@ jobs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
@@ -30,7 +28,7 @@ jobs:
- 'management/**'
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -38,10 +36,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
id: cache
with:
path: |
@@ -115,16 +113,14 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: ["386", "amd64"]
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -132,10 +128,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -158,20 +154,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
needs: [ build-cache ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -183,7 +177,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
@@ -220,7 +214,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:
@@ -237,12 +231,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -254,10 +246,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -285,16 +277,14 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: ["386", "amd64"]
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -308,7 +298,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -334,16 +324,14 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: ["386", "amd64"]
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -355,10 +343,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -382,21 +370,19 @@ jobs:
test_management:
name: "Management / Unit"
needs: [build-cache]
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres", "mysql"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -404,10 +390,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -424,7 +410,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -441,7 +427,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
@@ -451,13 +437,13 @@ jobs:
benchmark:
name: "Management / Benchmark"
needs: [build-cache]
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -488,12 +474,10 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -501,10 +485,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -521,7 +505,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -545,13 +529,13 @@ jobs:
api_benchmark:
name: "Management / Benchmark (API)"
needs: [build-cache]
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -582,12 +566,10 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -595,10 +577,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -615,7 +597,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -641,22 +623,20 @@ jobs:
api_integration_test:
name: "Management / Integration"
needs: [build-cache]
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -664,10 +644,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}

View File

@@ -18,12 +18,10 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
id: go
with:
go-version-file: "go.mod"
@@ -35,7 +33,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
@@ -46,15 +44,16 @@ jobs:
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompressing wintun files
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'

View File

@@ -15,11 +15,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
@@ -40,15 +38,13 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -56,7 +52,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
skip-cache: true

View File

@@ -22,9 +22,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: run install script
env:

View File

@@ -16,25 +16,23 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Setup Android SDK
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -54,11 +52,9 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: install gomobile

View File

@@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;

View File

@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
@@ -20,66 +20,34 @@ jobs:
per_page: 100,
});
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
async function getVersionHeader(path, ref) {
try {
const res = await github.rest.repos.getContent({
owner: context.repo.owner,
repo: context.repo.repo,
path,
ref,
});
if (!res.data.content) {
return { ok: false, reason: 'no inline content (file too large)' };
}
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
const lines = content
.split('\n')
.slice(0, 20)
.filter(line => versionPattern.test(line));
return { ok: true, lines };
} catch (e) {
return { ok: false, reason: e.message };
}
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of modifiedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
base: base.lines,
head: head.lines,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n` +
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.5"
SIGN_PIPE_VER: "v0.1.4"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -24,9 +24,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
@@ -53,26 +51,19 @@ jobs:
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
envs: "GO_VERSION"
prepare: |
# Install required packages
pkg install -y git curl portlint
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
@@ -102,19 +93,19 @@ jobs:
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
@@ -133,25 +124,26 @@ jobs:
env:
flags: ""
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -164,18 +156,18 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
uses: docker/setup-buildx-action@v2
- name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -199,7 +191,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
@@ -290,28 +282,28 @@ jobs:
} >> "$GITHUB_OUTPUT"
- name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 7
- name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 7
- name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 7
- name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -322,26 +314,27 @@ jobs:
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -382,7 +375,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -411,7 +404,7 @@ jobs:
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: release-ui
path: dist/
@@ -425,17 +418,16 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -449,7 +441,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
@@ -457,7 +449,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
@@ -482,26 +474,27 @@ jobs:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
uses: actions/download-artifact@v4
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
uses: actions/download-artifact@v4
with:
name: release-ui
path: release-ui
@@ -521,27 +514,29 @@ jobs:
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompress wintun files
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
destination: ${{ env.downloadPath }}\mesa3d.7z
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
@@ -552,38 +547,35 @@ jobs:
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: carlosperate/download-file-action@v2
with:
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
destination: ${{ github.workspace }}\envar_plugin.zip
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
shell: pwsh
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
run: |
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
Copy-Item -Destination $nsisPluginDir -Force
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
@@ -600,7 +592,7 @@ jobs:
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
@@ -619,7 +611,7 @@ jobs:
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
@@ -711,7 +703,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign bin and installer
repo: netbirdio/sign-pipelines

View File

@@ -14,9 +14,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'
inputs: '{ "sha": "${{ github.sha }}" }'

View File

@@ -3,7 +3,7 @@ name: sync tag
on:
push:
tags:
- "v*"
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
@@ -29,7 +29,7 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
@@ -42,10 +42,10 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -6,10 +6,10 @@ on:
- main
pull_request:
paths:
- "infrastructure_files/**"
- ".github/workflows/test-infrastructure-files.yml"
- "management/cmd/**"
- "signal/cmd/**"
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
store: ["sqlite", "postgres", "mysql"]
store: [ 'sqlite', 'postgres', 'mysql' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
@@ -68,17 +68,15 @@ jobs:
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -141,8 +139,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
@@ -256,9 +254,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh

View File

@@ -3,9 +3,9 @@ name: update docs
on:
push:
tags:
- "v*"
- 'v*'
paths:
- "shared/management/http/api/openapi.yml"
- 'shared/management/http/api/openapi.yml'
jobs:
trigger_docs_api_update:
@@ -13,10 +13,10 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger API pages generation
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: generate api pages
repo: netbirdio/docs
ref: "refs/heads/main"
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -19,17 +19,15 @@ jobs:
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
install-mode: binary
@@ -44,11 +42,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Build Wasm client
@@ -69,3 +65,4 @@ jobs:
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi

View File

@@ -15,7 +15,6 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
- [Contributing to NetBird](#contributing-to-netbird)
- [Contents](#contents)
- [Code of conduct](#code-of-conduct)
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
- [Directory structure](#directory-structure)
- [Development setup](#development-setup)
- [Requirements](#requirements)
@@ -34,14 +33,6 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
By participating, you are expected to uphold this code. Please report
unacceptable behavior to community@netbird.io.
## Discuss changes with the NetBird team first
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
## Directory structure
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.

153
README.md
View File

@@ -1,134 +1,147 @@
<div align="center">
<p align="center">
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
</p>
<p align="center">
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
</a>
<br/>
<br/>
<p align="center">
<img width="234" src="docs/media/logo-full.png"/>
</p>
<p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
</a>
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
</a>
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
</a>
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p>
</div>
<p align="center">
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
</strong>
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
<br/>
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
</p>
<br>
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open Source Network Security in a Single Platform
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-host NetBird (video)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features
| Connectivity | Management | Security | Automation | Platforms |
|---|---|---|---|---|
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
| Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
| [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
| | | | | ✓ OpenWRT |
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
| Connectivity | Management | Security | Automation| Platforms |
|----|----|----|----|----|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
||||| <ul><li>- \[x] Docker</ui></li> |
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
- Check the NetBird [admin UI](https://app.netbird.io/).
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
- Check NetBird [admin UI](https://app.netbird.io/).
- Add more machines.
### Quickstart with self-hosted NetBird
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
**Infrastructure requirements:**
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
- A **public domain** name pointing to the VM.
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
### A bit on NetBird internals
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
### Support acknowledgement
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Acknowledgements
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
### Legal
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -11,7 +11,7 @@ import (
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/management-integrations/integrations"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Fatal(err)
}
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err)

View File

@@ -12,7 +12,6 @@ import (
"sync"
"github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
@@ -85,12 +84,6 @@ type Options struct {
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// BlockLANAccess blocks the embedded peer from reaching the host's
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
// when the embedded client must never act as a stepping stone into
// the host's local network (e.g. the proxy's overlay peer).
BlockLANAccess bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
@@ -101,26 +94,6 @@ type Options struct {
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
// Performance configures the tunnel's buffer pool cap and batch size.
Performance Performance
}
// Performance configures the embedded client's tunnel memory/throughput knobs.
//
// These settings are process-global: any non-nil field also becomes the
// default for Clients constructed by later embed.New calls in the same
// process. Nil fields are ignored.
type Performance struct {
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
// leaves the pool unbounded. Lower values trade throughput for a
// tighter memory ceiling. May also be changed on a running Client via
// Client.SetPerformance, provided this field was nonzero at construction.
PreallocatedBuffersPerPool *uint32
// MaxBatchSize overrides the number of packets the tunnel reads or
// writes per syscall, which also bounds eager buffer allocation per
// worker. Zero uses the platform default. Applied at construction
// only; ignored by Client.SetPerformance.
MaxBatchSize *uint32
}
// validateCredentials checks that exactly one credential type is provided
@@ -202,7 +175,6 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound,
BlockLANAccess: &opts.BlockLANAccess,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
@@ -220,13 +192,6 @@ func New(opts Options) (*Client, error) {
config.PrivateKey = opts.PrivateKey
}
if opts.Performance.PreallocatedBuffersPerPool != nil {
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
}
if opts.Performance.MaxBatchSize != nil {
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
@@ -440,21 +405,6 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
}, nil
}
// IdentityForIP looks up a remote peer by its tunnel IP using the
// embedded client's status recorder. Returns the peer's WireGuard public
// key and FQDN. ok=false means the IP isn't in this client's peer
// roster — callers should treat that as "unknown peer".
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if !ip.IsValid() || c.recorder == nil {
return "", "", false
}
state, found := c.recorder.PeerStateByIP(ip.String())
if !found {
return "", "", false
}
return state.PubKey, state.FQDN, true
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
@@ -523,25 +473,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
// takes effect, and only when it was nonzero at construction;
// MaxBatchSize is construction-only and returns an error if set here.
//
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
// running yet.
func (c *Client) SetPerformance(t Performance) error {
if t.MaxBatchSize != nil {
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
}
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetPerformance(internal.Performance{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.

View File

@@ -21,7 +21,7 @@ func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, nil, disableServerRoutes, flowLogger, mtu)
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}

View File

@@ -29,80 +29,47 @@ const (
NFTABLES
)
// SkipNftablesEnv is the environment variable to skip nftables check
const SkipNftablesEnv = "NB_SKIP_NFTABLES_CHECK"
// errNoFirewallManager indicates no kernel firewall backend is present,
// as opposed to a backend that exists but failed to create or initialize.
var errNoFirewallManager = errors.New("no firewall manager found")
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// We run in userspace mode and force userspace firewall was requested.
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
nativeFw, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding without it", err)
}
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
switch {
case err == nil && !iface.IsUserspaceBind():
// Nothing to do, fall through
case err == nil && iface.IsUserspaceBind():
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
case err != nil && !iface.IsUserspaceBind():
// Kernel cannot fall back to anything else, need to return error
return nil, err
case err != nil && iface.IsUserspaceBind():
// Fall back to the userspace packet filter if native is unavailable
logNativeFirewallUnavailable(err)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
}
// createUserspaceFirewall builds the userspace packet filter, optionally
// backed by a native firewall, and allows netbird interface traffic.
func createUserspaceFirewall(iface IFaceMapper, nativeFw firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
fm, err := uspfilter.Create(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
// logNativeFirewallUnavailable logs the fallback to userspace at info level
// when no kernel firewall backend exists, and at warn level otherwise.
func logNativeFirewallUnavailable(err error) {
if errors.Is(err, errNoFirewallManager) {
log.Infof("no native firewall backend available: %v. Proceeding with userspace", err)
} else {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface, mtu)
if err != nil {
return nil, fmt.Errorf("create firewall: %w", err)
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
@@ -121,10 +88,29 @@ func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface, mtu)
default:
return nil, errNoFirewallManager
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
}
if errUsp != nil {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
useIPTABLES := false
@@ -146,38 +132,35 @@ func check() FWType {
}
}
// Honor the skip env before probing nftables at all.
if os.Getenv(SkipNftablesEnv) != "true" {
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil {
if !useIPTABLES {
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES {
return NFTABLES
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
}

View File

@@ -0,0 +1,554 @@
package iptables
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
return &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
specs: specs,
v6: m.v6,
}}, nil
}
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
}
if err := m.createIPSet(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err)
}
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// Insert at the beginning of the chain (position 1)
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
} else {
err = m.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
return nil, err
}
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
mangleSpecs: mangleSpecs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
v6: m.v6,
}
m.updateState()
return []firewall.Rule{rule}, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
ip := net.ParseIP(r.ip)
if ip == nil {
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ipsetList.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
shouldDestroyIpset = true
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
}
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState()
return nil
}
func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["INPUT"] {
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range m.entries["FORWARD"] {
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
}
if err := m.destroyIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
}
m.ipsetStore.deleteIpset(ipsetName)
}
return nil
}
func (m *aclManager) createDefaultChains() error {
// chain netbird-acl-input-rules
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
if ipsetName == "" {
return ""
}
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
}
switch {
case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport" + actionSuffix
case sPort != nil:
return ipsetName + "-sport" + actionSuffix
case dPort != nil:
return ipsetName + "-dport" + actionSuffix
default:
return ipsetName + actionSuffix
}
}
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -1,352 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFwdIn, tableFilter},
{chainRTFwdOut, tableFilter},
{chainRTPre, tableMangle},
{chainRTNAT, tableNat},
{chainRTRdr, tableNat},
{chainRTMSSClamp, tableMangle},
} {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFwdIn); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFwdOut); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil
}
func (r *family) addJumpRules() error {
// Jump to nat chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPostrouting, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %w", err)
}
r.rules[jumpNATPost] = natRule
// Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPre}
if err := r.iptablesClient.Insert(tableMangle, chainPrerouting, 1, preRule...); err != nil {
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
}
r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRdr}
if err := r.iptablesClient.Insert(tableNat, chainPrerouting, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %w", err)
}
r.rules[jumpNATPre] = rdrRule
return nil
}
func (r *family) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPrerouting, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPostrouting, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
// seedInitialEntries adds default rules to the entries map. Rules are
// inserted at position 1, so the order here is reversed.
//
// Existing FORWARD policy decides outbound traffic towards our
// interface. If FORWARD policy is "drop", we add an
// established/related rule to allow return traffic for inbound rules.
func (r *family) seedInitialEntries() {
established := getConntrackEstablished()
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", chainACLInput})
r.appendToEntries(chainInput, append([]string{"-i", r.wgIface.Name()}, established...))
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainForward, []string{"-o", r.wgIface.Name(), "-j", chainRTFwdOut})
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", chainRTFwdIn})
// Mangle FORWARD guard: when external DNAT redirects traffic from
// the wg interface, it traverses FORWARD instead of INPUT,
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
// inserted above ours. Mangle runs before filter, so these guard
// rules enforce the ACL mark check where it cannot be overridden.
r.appendToEntries(mangleForwardKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
r.appendToEntries(mangleForwardKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (r *family) seedInitialOptionalEntries() {
r.optionalEntries[chainForward] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
r.entries[chain] = append(r.entries[chain], spec)
}
func (r *family) createDefaultChains() error {
if err := r.iptablesClient.NewChain(tableName, chainACLInput); err != nil {
return fmt.Errorf("create %s chain: %w", chainACLInput, err)
}
for chain, rules := range r.entries {
// mangle FORWARD guard rules are handled separately below
if chain == mangleForwardKey {
continue
}
for _, rule := range rules {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), 1, rule...); err != nil {
return fmt.Errorf("insert jump rule into %s: %w", chain, err)
}
}
}
for chain, entries := range r.optionalEntries {
for _, entry := range entries {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
r.entries[chain] = append(r.entries[chain], entry.spec)
}
}
clear(r.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range r.entries[mangleForwardKey] {
if err := r.iptablesClient.AppendUnique(tableMangle, chainForward, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
func (r *family) cleanUpDefaultForwardRules() error {
var merr *multierror.Error
// cleanJumpRules removes the OUTPUT jump to NETBIRD-NAT-OUTPUT among
// the others, so the chain below deletes cleanly instead of failing
// with "device or resource busy".
if err := r.cleanJumpRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clean jump rules: %w", err))
}
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFwdIn, tableFilter},
{chainRTFwdOut, tableFilter},
{chainRTPre, tableMangle},
{chainRTNAT, tableNat},
{chainRTRdr, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSClamp, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
continue
}
if ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
}
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanJumpRules() error {
// locations maps each tracked jump rule to the built-in table and
// chain it was inserted into.
locations := map[firewall.RuleID]struct{ table, chain string }{
jumpNATPost: {tableNat, chainPostrouting},
jumpManglePre: {tableMangle, chainPrerouting},
jumpNATPre: {tableNat, chainPrerouting},
jumpMSSClamp: {tableMangle, chainForward},
jumpNATOutput: {tableNat, chainOutput},
}
var merr *multierror.Error
for ruleID, loc := range locations {
rule, exists := r.rules[ruleID]
if !exists {
continue
}
if err := r.iptablesClient.DeleteIfExists(loc.table, loc.chain, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule from chain %s in table %s: %w", loc.chain, loc.table, err))
continue
}
delete(r.rules, ruleID)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanAclChains() error {
var merr *multierror.Error
if err := r.cleanInputAclChain(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanPreroutingEntries(); err != nil {
merr = multierror.Append(merr, err)
}
for _, rule := range r.entries[mangleForwardKey] {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete mangle %s guard rule %v: %w", chainForward, rule, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanInputAclChain() error {
ok, err := r.iptablesClient.ChainExists(tableName, chainACLInput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainACLInput, err)
}
if !ok {
return nil
}
var merr *multierror.Error
for _, rule := range r.entries[chainInput] {
if err := r.iptablesClient.DeleteIfExists(tableName, chainInput, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainInput, rule, err))
}
}
for _, rule := range r.entries[chainForward] {
if err := r.iptablesClient.DeleteIfExists(tableName, chainForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainForward, rule, err))
}
}
if err := r.iptablesClient.ClearAndDeleteChain(tableName, chainACLInput); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clear and delete %s chain: %w", chainACLInput, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanPreroutingEntries() error {
ok, err := r.iptablesClient.ChainExists(tableMangle, chainPrerouting)
if err != nil {
return fmt.Errorf("check chain %s in %s: %w", chainPrerouting, tableMangle, err)
}
if !ok {
return nil
}
var merr *multierror.Error
for _, rule := range r.entries[chainPrerouting] {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainPrerouting, rule, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPostrouting, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,269 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleID := rule.ID()
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[firewall.RuleID]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleID+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRdr,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleID+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleID+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFwdOut,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleID := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleID+dnatSuffix)
}
if snatRule, exists := r.rules[ruleID+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleID+snatSuffix)
}
if fwdRule, exists := r.rules[ruleID+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFwdOut, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleID+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
info := ruleInfo{
table: tableNat,
chain: chainRTRdr,
rule: dnatRule,
}
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = info.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNATOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, chainOutput, 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNATOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}

View File

@@ -1,246 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableName = tableFilter
tableNat = "nat"
tableMangle = "mangle"
// chainACLInput is the peer ACL chain that holds installed
// peer-filtering rules.
chainACLInput = "NETBIRD-ACL-INPUT"
// mangleForwardKey is the entries map key for mangle FORWARD guard
// rules that prevent external DNAT from bypassing ACL rules.
mangleForwardKey chainKey = "MANGLE-FORWARD"
chainInput = "INPUT"
chainPostrouting = "POSTROUTING"
chainPrerouting = "PREROUTING"
chainForward = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFwdIn = "NETBIRD-RT-FWD-IN"
chainRTFwdOut = "NETBIRD-RT-FWD-OUT"
chainRTPre = "NETBIRD-RT-PRE"
chainRTRdr = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSClamp = "NETBIRD-RT-MSSCLAMP"
jumpManglePre = "jump-mangle-pre"
jumpNATPre = "jump-nat-pre"
jumpNATPost = "jump-nat-post"
jumpNATOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
fwdSuffix firewall.RuleID = "_fwd"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
)
type ruleInfo struct {
chain string
table string
rule []string
}
type routeRules map[firewall.RuleID][]string
// ruleSpec is a single iptables rule expressed as its argument list
// (e.g. {"-i", "wg0", "-j", "DROP"}).
type ruleSpec []string
// chainKey identifies the chain a seeded entry belongs to. It holds
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
// synthetic mangleForwardKey bucket for the mangle FORWARD guard rules.
type chainKey string
// aclEntries maps a chain to the rules seeded into it to jump into or
// guard the netbird ACL chains.
type aclEntries map[chainKey][]ruleSpec
type entry struct {
spec ruleSpec
position int
}
// ipsetCounter is the shared hash:net refcounter used by peer and
// route ACLs alike. The ipset library does not support comments, so
// the key is just the set name (string).
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
// family holds the per-address-family iptables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6.
type family struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
v6 bool
// Peer ACL chain bookkeeping.
entries aclEntries
optionalEntries map[chainKey][]entry
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[nbid.RuleID]*Rule
ipsetCounter *ipsetCounter
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
// plumbing that isn't a filter rule).
rules routeRules
// Routing / NAT.
legacyManagement bool
mtu uint16
ipFwdState *ipfwdstate.IPForwardingState
stateManager *statemanager.Manager
}
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
iptablesClient: iptablesClient,
wgIface: wgIface,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
entries: make(aclEntries),
optionalEntries: make(map[chainKey][]entry),
filters: make(map[nbid.RuleID]*Rule),
rules: make(routeRules),
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
return r, nil
}
// init wires the family to the state manager and installs both the
// route ACL containers and the peer ACL chain skeleton.
func (r *family) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.seedInitialEntries()
r.seedInitialOptionalEntries()
if err := r.cleanAclChains(); err != nil {
return fmt.Errorf("clean acl chains: %w", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
r.updateState()
return nil
}
// Reset tears down all firewall state owned by this family. ACL
// chain cleanup runs before route-chain cleanup because the route
// chains are still referenced by FORWARD jumps installed during
// seedInitialEntries; deleting them first would trip EBUSY.
func (r *family) Reset() error {
var merr *multierror.Error
if err := r.cleanAclChains(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
clear(r.rules)
clear(r.filters)
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
currentState.ACLEntries6 = r.entries
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
currentState.ACLEntries = r.entries
}
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}

View File

@@ -1,334 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nbnet "github.com/netbirdio/netbird/client/net"
)
// AddFilterRule installs a packet-filtering rule. With destination
// empty, the rule goes to the peer ACL input chain plus a paired
// mangle PREROUTING rule for the redirect mark. With destination set
// (prefix or named set), it goes to the route ACL forward chain.
// Multi-source rules collapse to one iptables rule via the shared
// hash:net ipset.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
if err != nil {
return nil, fmt.Errorf("apply source match: %w", err)
}
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
if err != nil {
r.dropSourceMatch(srcMatch)
return nil, err
}
r.filters[ruleID] = rule
r.updateState()
return rule, nil
}
func (r *family) hasRule(id nbid.RuleID) bool {
_, ok := r.filters[id]
return ok
}
// hasDNATRule reports whether this family owns the DNAT rule set for
// the given user id. DNAT rules live in r.rules under the well-known
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
// to pick the right family.
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. The
// rule's stored chain/table identify where to delete from; source set
// references are recovered from the spec via findSets and dropped
// from the shared ipset counter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
if err := r.iptablesClient.Delete(tableFilter, pr.chain, pr.specs...); err != nil {
return fmt.Errorf("delete rule %s: %w", pr.chain, err)
}
if pr.mangleSpecs != nil {
if err := r.iptablesClient.Delete(tableMangle, chainRTPre, pr.mangleSpecs...); err != nil {
log.Errorf("delete mangle rule: %v", err)
}
}
r.dropSourceMatch(pr.specs)
delete(r.filters, ruleID)
r.updateState()
return nil
}
// findSets scans an iptables rule spec for "-m set --match-set <name>
// <dir>" fragments and returns the named sets in occurrence order.
// Used at delete time to drop ipsetCounter references.
func findSets(rule []string) []string {
var sets []string
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3])
}
}
return sets
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
// applySourceMatch returns the iptables match fragment for the rule's
// source. For a Set it increments the shared ipset's refcount; for a
// Prefix it emits a direct -s match; for the wildcard it returns nil.
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
switch {
case network.IsSet():
if r.ipsetCounter == nil {
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
}
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
}
return []string{"-m", "set", matchSet, name, "src"}, nil
case network.IsPrefix():
return []string{"-s", network.Prefix.String()}, nil
default:
return nil, nil
}
}
// dropSourceMatch undoes whatever applySourceMatch reserved. Safe to
// call when the spec is empty or holds only inline matchers. Decrement
// errors are logged but not returned: the filter rule has already been
// deleted at that point and we don't want to leak the deletion.
func (r *family) dropSourceMatch(srcMatch []string) {
if r.ipsetCounter == nil {
return
}
for _, name := range findSets(srcMatch) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
log.Errorf("rollback ipset decrement %s: %v", name, err)
}
}
}
// decrementSetCounter drops ipset references owned by a raw rule spec
// stored in r.rules (NAT / legacy route entries). It returns an error
// aggregate so the caller surfaces decrement failures.
func (r *family) decrementSetCounter(rule []string) error {
if r.ipsetCounter == nil {
return nil
}
var merr *multierror.Error
for _, name := range findSets(rule) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// installFilterRule assembles and writes one iptables filter-chain
// rule. With destination empty the rule lands in the peer ACL input
// chain and a paired mangle PREROUTING rule is added for the redirect
// mark. With destination set the rule lands in the route ACL forward
// chain and there is no mangle pairing.
func (r *family) installFilterRule(
ruleID nbid.RuleID,
srcMatch []string,
destination firewall.Network,
protocol firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (*Rule, error) {
isRoute := !destination.IsZero()
proto := protoForFamily(protocol, r.v6)
specs := slices.Clone(srcMatch)
var destExp []string
if isRoute {
var err error
destExp, err = r.applyNetwork("-d", destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
}
specs = append(specs, destExp...)
}
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
var mangleSpecs []string
if !isRoute {
mangleSpecs = slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", r.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
}
specs = append(specs, "-j", actionToStr(action))
chain := chainACLInput
if isRoute {
chain = chainRTFwdIn
}
// Peer ACL drops are inserted at position 1 so they precede the
// chain's catch-all; route ACL drops are inserted at position 2
// to sit immediately after the established/related accept rule.
var err error
if action == firewall.ActionDrop {
pos := 1
if isRoute {
pos = 2
}
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
} else {
err = r.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
r.dropSourceMatch(destExp)
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
}
// The mangle redirect-mark rule is best effort: the filter rule itself
// is what enforces the ACL, so a mangle failure must not undo it. Drop
// the spec so teardown does not try to remove a rule that was not added.
if mangleSpecs != nil {
if err := r.iptablesClient.Append(tableMangle, chainRTPre, mangleSpecs...); err != nil {
log.Errorf("add mangle rule: %v", err)
mangleSpecs = nil
}
}
return &Rule{
id: ruleID,
specs: specs,
mangleSpecs: mangleSpecs,
chain: chain,
v6: r.v6,
}, nil
}
// applyNetwork resolves a firewall.Network into the iptables match
// fragment for the given direction flag (-s or -d). Set networks
// increment the shared ipset refcount; prefixes emit a direct match;
// an empty network returns no spec ("match any").
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, name, direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
// filterMatchSpecs returns the proto/port match fragment for a
// filtering rule. The source match (-s or -m set) is built by the
// caller and prepended.
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
}
if port.IsRange && len(port.Values) == 2 {
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
}
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(int(p))
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}

View File

@@ -1,104 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"github.com/hashicorp/go-multierror"
"github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
if err := r.createIPSet(setName); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
// The refcounter records nothing when this callback errors,
// so destroy the set or it leaks in the kernel. A partial
// source set would also fail-open for deny rules, so the
// rule must fail rather than install with a missing source.
if derr := r.destroyIPSet(setName); derr != nil {
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
}
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
}
func (r *family) deleteIpSet(setName string) error {
if err := r.destroyIPSet(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
log.Debugf("deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error
for _, prefix := range prefixes {
if err := r.addPrefixToIPSet(name, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", name, prefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *family) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *family) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -3,6 +3,7 @@ package iptables
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
@@ -17,21 +18,25 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager of iptables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
type resetter interface {
Reset() error
}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
wgIface iFaceMapper
ipv4Client *iptables.IPTables
family4 *family
aclMgr *aclManager
router *router
rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
family6 *family
aclMgr6 *aclManager
router6 *router
}
// iFaceMapper defines subset methods of interface required for manager
@@ -52,9 +57,14 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
m.router, err = newRouter(iptablesClient, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create family: %w", err)
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -71,18 +81,21 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
family6, err := newFamily(ip6Client, wgIface, mtu)
m.router6, err = newRouter(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 family: %w", err)
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 family, since
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
family6.ipFwdState = m.family4.ipFwdState
m.router6.ipFwdState = m.router.ipFwdState
m.ipv6Client = ip6Client
m.family6 = family6
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
@@ -96,7 +109,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.family4.mtu,
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -128,24 +141,31 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil
}
// initChains initializes the per-family firewall state for both
// address families, rolling back on failure.
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
r *family
init func(*statemanager.Manager) error
mgr resetter
}
steps := []initStep{{"v4", m.family4}}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
if m.hasIPv6() {
steps = append(steps, initStep{"v6", m.family6})
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
}
var initialized []initStep
for _, s := range steps {
if err := s.r.init(stateManager); err != nil {
if err := s.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].r.Reset(); rerr != nil {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
@@ -156,50 +176,84 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
return nil
}
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
// docs for destination semantics. Sources are a single address family;
// the rule is dispatched to the matching v4 / v6 backend.
func (m *Manager) AddFilterRule(
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
fam := m.family4
if isIPv6Rule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
}
fam = m.family6
if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
// DeleteFilterRule removes a rule previously added via AddFilterRule.
// The rule is looked up by id in each family's filter cache.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
id := rule.ID()
if m.family4.hasRule(id) {
return m.family4.DeleteFilterRule(rule)
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
if m.hasIPv6() && m.family6.hasRule(id) {
return m.family6.DeleteFilterRule(rule)
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
log.Debugf("filter rule %s not found in any family", id)
return nil
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
}
return m.aclMgr.DeletePeerRule(rule)
}
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -218,10 +272,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddNatRule(pair)
return m.router6.AddNatRule(pair)
}
if err := m.family4.AddNatRule(pair); err != nil {
if err := m.router.AddNatRule(pair); err != nil {
return err
}
@@ -230,7 +284,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.AddNatRule(v6Pair); err != nil {
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -246,18 +300,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.family6.RemoveNatRule(pair)
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.family4.RemoveNatRule(pair); err != nil {
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -266,11 +320,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.family6, isLegacy)
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
}
@@ -287,13 +341,19 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
}
if m.hasIPv6() {
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
}
}
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
}
// Appending to merr intentionally blocks DeleteState below so ShutdownState
@@ -317,11 +377,11 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
var merr *multierror.Error
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
}
if m.hasIPv6() {
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
}
}
@@ -342,14 +402,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -364,9 +424,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddDNATRule(rule)
return m.router6.AddDNATRule(rule)
}
return m.family4.AddDNATRule(rule)
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -374,10 +434,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
return m.family6.DeleteDNATRule(rule)
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
}
return m.family4.DeleteDNATRule(rule)
return m.router.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
@@ -394,12 +454,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -416,9 +476,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -430,9 +490,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -444,9 +504,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -458,14 +518,14 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOutput = "OUTPUT"
chainOUTPUT = "OUTPUT"
tableRaw = "raw"
)
@@ -540,15 +600,15 @@ func (m *Manager) initNoTrackChain() error {
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOutput, 1, jumpRule...); err != nil {
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPrerouting, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); delErr != nil {
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
@@ -575,11 +635,11 @@ func (m *Manager) cleanupNoTrackChain() error {
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); err != nil {
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPrerouting, jumpRule...); err != nil {
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
@@ -594,13 +654,3 @@ func (m *Manager) cleanupNoTrackChain() error {
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}
// isIPv6Rule reports whether the rule belongs to the IPv6 family, from
// the destination prefix when set, otherwise from the (single-family)
// sources.
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,5 +1,3 @@
//go:build integration && !android
package iptables
import (
@@ -67,39 +65,46 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 fw.Rule
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
rr := rule2.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("delete second rule", func(t *testing.T) {
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainACLInput)
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainACLInput)
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
}
})
}
@@ -121,13 +126,15 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{22}}
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
require.NoError(t, err, "failed to add deny rule")
require.NotNil(t, rule, "deny rule should not be nil")
require.NotEmpty(t, rule, "deny rule should not be empty")
// Verify the rule was added by checking iptables
rr := rule.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
for _, r := range rule {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("deny rule precedence test", func(t *testing.T) {
@@ -135,40 +142,36 @@ func TestIptablesManagerDenyRules(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
// Add accept rule first
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for same IP/port - this should take precedence
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
// Inspect the actual iptables rules to verify deny rule comes before accept rule
rules, err := ipv4Client.List("filter", chainACLInput)
rules, err := ipv4Client.List("filter", chainNameInputRules)
require.NoError(t, err, "failed to list iptables rules")
// Debug: print all rules
t.Logf("All iptables rules in chain %s:", chainACLInput)
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
for i, rule := range rules {
t.Logf(" [%d] %s", i, rule)
}
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
// match. Match on that shape instead of the legacy
// per-(action,port) ipset names ("deny-http"/"accept-http")
// that this test predates.
srcMatch := fmt.Sprintf("-s %s/32", ip)
var denyRuleIndex, acceptRuleIndex = -1, -1
for i, rule := range rules {
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
continue
}
if strings.Contains(rule, "-j DROP") {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
denyRuleIndex = i
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
denyRuleIndex = i
}
}
if strings.Contains(rule, "-j ACCEPT") {
if strings.Contains(rule, "ACCEPT") {
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
acceptRuleIndex = i
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
acceptRuleIndex = i
}
}
}
@@ -193,6 +196,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
}
// just check on the local interface
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -206,39 +210,27 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 fw.Rule
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
require.NotNil(t, rule2)
require.Contains(t, rule2.(*Rule).specs, "-s",
"single-source rule should use direct -s match, not an ipset")
require.Empty(t, findSets(rule2.(*Rule).specs),
"single-source rule should not allocate a shared ipset")
})
t.Run("delete single-source rule", func(t *testing.T) {
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
})
t.Run("multi-source uses shared ipset", func(t *testing.T) {
sources := []netip.Prefix{
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
}
port := &fw.Port{Values: []uint16{8080}}
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add multi-source rule")
require.NotNil(t, multi, "multi-source rule must produce one iptables rule")
sets := findSets(multi.(*Rule).specs)
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
})
require.NoError(t, manager.DeleteFilterRule(multi))
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
}
})
t.Run("reset check", func(t *testing.T) {
@@ -289,7 +281,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build integration && !android
//go:build !android
package iptables
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
@@ -52,12 +52,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPostrouting, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPostrouting)
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPrerouting, "-j", chainRTPre)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPrerouting)
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
@@ -95,7 +95,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -106,8 +106,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "marking rule should be created")
foundRule, found := manager.rules[natRuleKey]
@@ -121,7 +121,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -132,8 +132,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
@@ -170,7 +170,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -181,8 +181,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey]
@@ -190,7 +190,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -201,8 +201,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family manager")
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset family")
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
@@ -334,30 +334,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
stored, ok := r.filters[ruleKey.ID()]
require.True(t, ok, "rule not stored in filters")
t.Logf("Internal rule: %v", stored.specs)
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
exists, err := iptablesClient.Exists(tableFilter, chainRTFwdIn, stored.specs...)
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content
params := routeFilteringRuleParams{
Source: source,
Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
}
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
expectedRule, err = r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
})
}
}
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
@@ -398,7 +430,7 @@ func TestFindSetNameInRule(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := findSets(tc.rule)
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)

View File

@@ -1,269 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if !pair.Masquerade {
return nil
}
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
r.updateState()
return nil
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", "ACCEPT"}
if err := r.iptablesClient.Append(tableFilter, chainRTFwdIn, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
r.rules[ruleID] = rule
return nil
}
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
}
return nil
}
// GetLegacyManagement returns the current legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
} else {
delete(r.rules, k)
}
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", "MASQUERADE",
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %w", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", "MASQUERADE",
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %w", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSClamp,
}
if err := r.iptablesClient.Insert(tableMangle, chainForward, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSClamp, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *family) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
ruleID := firewall.RuleID("established-" + chain)
r.rules[ruleID] = establishedRule
return nil
}
func (r *family) addNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.NatFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleID)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
)
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPre, 1, rule...); err != nil {
r.dropSourceMatch(rule)
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
}
r.rules[ruleID] = rule
r.updateState()
return nil
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.NatFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("marking rule %s not found", ruleID)
}
r.updateState()
return nil
}

View File

@@ -1,20 +1,18 @@
package iptables
import "github.com/netbirdio/netbird/client/firewall/manager"
// Rule to handle management of rules. Source set membership (when the
// rule was built against a shared hash:net ipset) is encoded in specs;
// DeleteFilterRule recovers it via findSets so the refcounter can drop
// the right reference.
// Rule to handle management of rules
type Rule struct {
id manager.RuleID
ruleID string
ipsetName string
specs []string
mangleSpecs []string
ip string
chain string
v6 bool
}
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -0,0 +1,103 @@
package iptables
import "encoding/json"
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) *ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return &ipList{
ips: ips,
}
}
func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
type ipsetStore struct {
ipsets map[string]*ipList
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]*ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) ipsetNames() []string {
names := make([]string, 0, len(s.ipsets))
for name := range s.ipsets {
names = append(names, name)
}
return names
}
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -29,13 +29,17 @@ type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
}
func (s *ShutdownState) Name() string {
@@ -53,14 +57,17 @@ func (s *ShutdownState) Cleanup() error {
}
if s.RouteRules != nil {
ipt.family4.rules = s.RouteRules
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.family4.entries = s.ACLEntries
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
// Clean up v6 state even if the current run has no IPv6.
@@ -72,13 +79,16 @@ func (s *ShutdownState) Cleanup() error {
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.family6.rules = s.RouteRules6
ipt.router6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.family6.entries = s.ACLEntries6
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
}
}

View File

@@ -1,27 +0,0 @@
//go:build integration && !android
package iptables
import (
"fmt"
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
}
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -3,6 +3,7 @@ package manager
import (
"errors"
"fmt"
"net"
"net/netip"
"sort"
@@ -15,12 +16,6 @@ import (
// method but the IPv6 firewall components were not initialized.
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
// ErrNoSources is returned when AddFilterRule is called with an empty
// source list. "Match any source" must be expressed explicitly with a
// /0 prefix; an empty list is a caller error and is rejected rather
// than silently widening the rule to every source.
var ErrNoSources = errors.New("rule has no sources")
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
@@ -28,18 +23,13 @@ const (
NatFormat = "netbird-nat-%s-%t"
)
// RuleID identifies a firewall rule. It is a typed string so the
// compiler catches accidental mixing with arbitrary string keys. It is
// only an identifier and does not implement Rule.
type RuleID string
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// ID returns the rule id
ID() RuleID
ID() string
}
// RuleDirection is the traffic direction which a rule is applied
@@ -101,13 +91,6 @@ func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// IsZero returns true if the network designates no destination, i.e. it
// is the zero value. A zero Network is the peer-rule sentinel; a non-zero
// one carries a prefix or set destination.
func (d Network) IsZero() bool {
return !d.IsPrefix() && !d.IsSet()
}
// Manager is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
@@ -118,42 +101,43 @@ type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFilterRule adds a packet-filtering rule to the firewall.
// AddPeerFiltering adds a rule to the firewall
//
// If destination is the zero Network, the rule applies to traffic
// inbound to this node, i.e. peer ACL semantics, installed in
// the kernel's input chain. If destination is set (prefix or
// set), the rule applies to forwarded traffic with that
// destination, route ACL semantics, installed in the forward
// chain.
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
//
// sources must be a single address family; the caller splits mixed
// families and calls once per family. "Match any source" must be
// expressed with an explicit /0 prefix; an empty sources list is
// rejected with ErrNoSources so a zeroed list can never widen a
// rule to every source.
//
// Note: callers should call Flush() after adding rules.
AddFilterRule(
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
ipsetName string,
) ([]Rule, error)
// DeleteFilterRule removes a filtering rule previously added via
// AddFilterRule. The rule's own type identifies whether it lives
// in the peer (input) or route (forward) path.
DeleteFilterRule(rule Rule) error
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
proto Protocol,
sPort, dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
@@ -201,9 +185,8 @@ type Manager interface {
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
// GenKey builds the rule id for this pair from the given format.
func (p RouterPair) GenKey(format string) RuleID {
return RuleID(fmt.Sprintf(format, p.ID, p.Inverse))
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}
// LegacyManager defines the interface for legacy management operations
@@ -259,20 +242,6 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
// plain v4 form, shifting the prefix length out of the 96-bit mapped
// range. Other prefixes are returned unchanged. Keeping prefixes
// unmapped ensures v4 rules match consistently and the match builders
// read the correct address length.
func UnmapPrefix(p netip.Prefix) netip.Prefix {
addr := p.Addr()
if !addr.Is4In6() {
return p
}
bits := max(p.Bits()-96, 0)
return netip.PrefixFrom(addr.Unmap(), bits)
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {

View File

@@ -13,13 +13,13 @@ type ForwardRule struct {
TranslatedPort Port
}
func (r ForwardRule) ID() RuleID {
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return RuleID(id)
return id
}
func (r ForwardRule) String() string {

View File

@@ -40,7 +40,7 @@ func (h Set) Comment() string {
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
prefixes = slices.Clone(prefixes)
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()

View File

@@ -0,0 +1,713 @@
package nftables
import (
"bytes"
"fmt"
"net"
"slices"
"strconv"
"strings"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
// rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
)
const flushError = "flush: %w"
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
af addrFamily
workTable *nftables.Table
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, fmt.Errorf("create nf conn: %w", err)
}
return &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
}, nil
}
func (m *AclManager) init(workTable *nftables.Table) error {
m.workTable = workTable
return m.createDefaultChains()
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
var err error
ipset, err = m.addIpToSet(ipsetName, ip)
if err != nil {
return nil, err
}
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil {
return nil, err
}
newRules = append(newRules, ioRule)
return newRules, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.nftSet == nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
return err
}
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ips) > 0 {
return nil
}
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err
}
delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.rConn.FlushSet(r.nftSet)
m.rConn.DelSet(r.nftSet)
m.ipsetStore.deleteIpset(r.nftSet.Name)
return nil
}
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainInputRules,
Position: 0,
Exprs: expIn,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *AclManager) Flush() error {
if err := m.flushWithBackoff(); err != nil {
return err
}
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
mangleRule: r.mangleRule,
nftSet: r.nftSet,
ruleID: r.ruleID,
ip: ip,
}, nil
}
var expressions []expr.Any
if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.protoOffset,
Len: uint32(1),
})
protoData, err := m.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: []byte{protoData},
})
}
rawIP := ipToBytes(ip, m.af)
// check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.srcAddrOffset,
Len: m.af.addrLen,
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
},
)
}
}
expressions = append(expressions, applyPort(sPort, true)...)
expressions = append(expressions, applyPort(dPort, false)...)
mainExpressions := slices.Clone(expressions)
switch action {
case firewall.ActionAccept:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(ruleId)
chain := m.chainInputRules
rule := &nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: mainExpressions,
UserData: userData,
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var nftRule *nftables.Rule
if action == firewall.ActionDrop {
nftRule = m.rConn.InsertRule(rule)
} else {
nftRule = m.rConn.AddRule(rule)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}
ruleStruct := &Rule{
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = ruleStruct
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return ruleStruct, nil
}
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
m.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
// Chain is created by route manager
// TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
}
chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain
}
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return m.rConn.AddChain(chain)
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ipToBytes(ip, m.af)
if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
m.ipsetStore.newIpset(ipset.Name)
}
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
return ipset, nil
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
m.ipsetStore.AddIpToSet(ipset.Name, ip)
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
return ipset, nil
}
// createSet in given table by name
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: m.af.setKeyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
}
func (m *AclManager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(m.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])]
if ok {
if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
}
}
return nil
}
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipset == nil {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipset.Name + rulesetID
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
// ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
}
return ip.To16()
}

View File

@@ -1,885 +0,0 @@
//go:build !android
package nftables
import (
"bytes"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePostrouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.addPostroutingRules()
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *family) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return nil
}
func (r *family) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) acceptFilterTableRules() error {
if r.filterTable == nil {
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available.
// Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
if err := r.acceptFilterRulesIptables(ipt); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return nil
}
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *family) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *family) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *family) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
r.applyExternalChainAccept(chain, intf)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
return
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
r.insertForwardIifRule(chain, intf, existing)
r.insertForwardOifEstablishedRule(chain, intf, existing)
}
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleIif] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
})
}
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleOif] {
return
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: append(exprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
})
}
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
if existing[userDataAcceptInputRule] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
})
}
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return nil, fmt.Errorf("list rules: %w", err)
}
present := map[string]bool{}
for _, rule := range rules {
if !isNetbirdAcceptRuleTag(rule.UserData) {
continue
}
present[string(rule.UserData)] = true
}
return present, nil
}
func isNetbirdAcceptRuleTag(userData []byte) bool {
switch string(userData) {
case userDataAcceptForwardRuleIif,
userDataAcceptForwardRuleOif,
userDataAcceptInputRule:
return true
}
return false
}
func (r *family) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
}
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != table.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue
}
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *family) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *family) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
return chains
}
func (r *family) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
}
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chainInputRules,
Position: 0,
Exprs: expIn,
})
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (r *family) Flush() error {
if err := r.flushWithBackoff(); err != nil {
return err
}
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (r *family) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if r.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := r.conn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (r *family) createDefaultChains() (err error) {
// chainNameInputRules
chain := r.createChain(chainNameInputRules)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
r.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
r.chainPrerouting = r.chains[chainNameManglePrerouting]
r.addFwmarkToForward(chainFwFilter)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: r.routingFwChainName,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (r *family) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
}
chain = r.conn.AddChain(chain)
insertReturnTrafficRule(r.conn, r.workTable, chain)
return chain
}
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return r.conn.AddChain(chain)
}
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (r *family) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if r.workTable == nil || chain == nil {
return nil
}
list, err := r.conn.GetRules(r.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
if !ok {
continue
}
if mangle {
if pr.mangleRule != nil {
*pr.mangleRule = *rule
}
} else {
*pr.nftRule = *rule
}
}
return nil
}

View File

@@ -1,533 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleID := rule.ID()
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := r.af.protoNum(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleID); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleID)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleID, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleID + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleID+dnatSuffix] = dnatRule
return nil
}
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleID firewall.RuleID, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
natTable := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
dnatRule := &nftables.Rule{
Table: natTable,
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: natTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleID + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleID+dnatSuffix] = dnatRule
return nil
}
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.dstAddrOffset,
Len: r.af.addrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleID + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleID+snatSuffix] = masqRule
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleID := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleID+dnatSuffix)
delete(r.rules, ruleID+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleID+snatSuffix]; exists {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleID+snatSuffix)
delete(r.rules, ruleID+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr == nil {
delete(r.rules, ruleID+dnatSuffix)
delete(r.rules, ruleID+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -52,10 +52,9 @@ func (m *externalChainMonitor) start() {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
done := make(chan struct{})
m.done = done
m.done = make(chan struct{})
go m.run(ctx, done)
go m.run(ctx)
}
func (m *externalChainMonitor) stop() {
@@ -73,8 +72,8 @@ func (m *externalChainMonitor) stop() {
<-done
}
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
defer close(done)
func (m *externalChainMonitor) run(ctx context.Context) {
defer close(m.done)
bo := &backoff.ExponentialBackOff{
InitialInterval: externalMonitorInitInterval,

View File

@@ -1,249 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
// Peer ACL chain names.
chainNameInputRules = "netbird-acl-input-rules"
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
flushError = "flush: %w"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
)
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
// family holds the per-address-family nftables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6. The name predates the peer-ACL absorption; it's effectively
// the per-family backend now.
type family struct {
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[firewall.RuleID]*Rule
// rules holds NAT, DNAT, and external accept rules (auxiliary
// plumbing that isn't a filter rule).
rules map[firewall.RuleID]*nftables.Rule
// Peer ACL chain handles.
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
routingFwChainName string
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
filters: make(map[firewall.RuleID]*Rule),
rules: make(map[firewall.RuleID]*nftables.Rule),
routingFwChainName: chainNameRoutingFw,
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
}
func (r *family) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from filter table: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default acl chains: %w", err)
}
return nil
}
// Reset cleans existing nftables filter table rules from the system
func (r *family) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
if err != nil {
return nil, fmt.Errorf("list tables: %w", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *family) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *family) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[firewall.RuleID]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
newRules[firewall.RuleID(rule.UserData)] = rule
}
}
}
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,441 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"net"
"net/netip"
"slices"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
)
// AddFilterRule installs one nftables packet-filter rule. With
// destination empty the rule goes to the peer ACL input chain plus a
// paired prerouting mangle rule for the redirect mark. With
// destination set (prefix or named set) it goes to the route ACL
// forward chain. Multi-source rules collapse to one nftables rule
// backed by the shared refcounted hash:net set.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
isRoute := !destination.IsZero()
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
var exprs []expr.Any
if isRoute {
exprs, err = r.buildRouteFilterExprs(srcExprs, destination, proto, sPort, dPort)
} else {
exprs, err = r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
}
if err != nil {
r.dropNetworkMatch(srcExprs)
return nil, err
}
mainExprs := slices.Clone(exprs)
verdict := expr.VerdictAccept
if action == firewall.ActionDrop {
verdict = expr.VerdictDrop
}
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
chain := r.chainInputRules
if isRoute {
chain = r.chains[chainNameRoutingFw]
}
userData := []byte(ruleID)
nftRule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: mainExprs,
UserData: userData,
}
if action == firewall.ActionDrop {
nftRule = r.conn.InsertRule(nftRule)
} else {
nftRule = r.conn.AddRule(nftRule)
}
if err := r.conn.Flush(); err != nil {
r.dropNetworkMatch(exprs)
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{
nftRule: nftRule,
sources: sources,
id: ruleID,
}
if !isRoute {
rule.mangleRule = r.createPreroutingRule(exprs, userData)
}
r.filters[ruleID] = rule
log.Debugf("added filter rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
sources, destination, proto, sPort, dPort, action)
return rule, nil
}
// buildPeerFilterExprs assembles the input-chain (peer ACL) match: the
// IP-header protocol byte read via Payload, then source, then ports
// (no counter), matching the historical peer shape so per-rule kernel
// state is identical to pre-unification.
func (r *family) buildPeerFilterExprs(
srcExprs []expr.Any,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
) ([]expr.Any, error) {
var exprs []expr.Any
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.protoOffset,
Len: 1,
},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
}
exprs = append(exprs, srcExprs...)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
return exprs, nil
}
// buildRouteFilterExprs assembles the forward-chain (route ACL) match:
// source, then destination, then optional proto/ports, then a counter.
func (r *family) buildRouteFilterExprs(
srcExprs []expr.Any,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
) ([]expr.Any, error) {
exprs := append([]expr.Any{}, srcExprs...)
destExprs, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExprs...)
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
r.dropNetworkMatch(destExprs)
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
return exprs, nil
}
func (r *family) hasRule(id firewall.RuleID) bool {
_, ok := r.filters[id]
return ok
}
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. Source
// set references are recovered from the stored rule's expressions via
// findSets and dropped from the shared refcounter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
// A freshly added rule carries no handle until it is read back from
// the kernel, and Flush only refreshes the peer chains. Pull live
// handles for this rule's chain before deciding it is stale so route
// rules (which Flush never refreshes) can actually be deleted.
if pr.nftRule.Handle == 0 {
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
log.Warnf("refresh handles for chain %s: %v", pr.nftRule.Chain.Name, err)
}
if pr.mangleRule != nil {
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Warnf("refresh mangle handles: %v", err)
}
}
}
if pr.nftRule.Handle == 0 {
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
r.dropNetworkMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
if err := r.conn.DelRule(pr.nftRule); err != nil {
log.Errorf("queue rule delete: %v", err)
}
if pr.mangleRule != nil {
if err := r.conn.DelRule(pr.mangleRule); err != nil {
log.Errorf("queue mangle rule delete: %v", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete %s: %w", ruleID, err)
}
r.dropNetworkMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
if r.ipsetCounter == nil {
return nil
}
sets := findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// dropNetworkMatch undoes whatever the source/destination match
// reserved. Safe to call when the spec is empty or holds only inline
// matchers.
func (r *family) dropNetworkMatch(exprs []expr.Any) {
if r.ipsetCounter == nil {
return
}
for _, e := range exprs {
lookup, ok := e.(*expr.Lookup)
if !ok {
continue
}
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
}
}
}
func (r *family) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
side := "destination"
if isSource {
side = "source"
}
return nil, fmt.Errorf("%s set: %w", side, err)
}
return exprs, nil
}
if network.IsPrefix() {
return prefixMatchExprs(r.af, network.Prefix, isSource), nil
}
return nil, nil
}
// prefixMatchExprs is the family-aware match sequence for a CIDR
// prefix. /0 returns nil; a host prefix (full bit length for the
// family) skips the bitwise step since the mask is all-ones. Shared
// between family and aclManager so both treat single prefixes
// identically.
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
offset := af.dstAddrOffset
if isSource {
offset = af.srcAddrOffset
}
ones := prefix.Bits()
if ones == 0 {
return nil
}
payload := &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: af.addrLen,
}
cmp := &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
}
if ones == af.totalBits {
return []expr.Any{payload, cmp}
}
mask := net.CIDRMask(ones, af.totalBits)
xor := make([]byte, af.addrLen)
return []expr.Any{
payload,
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: af.addrLen,
Mask: mask,
Xor: xor,
},
cmp,
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
// src
offset := uint32(2)
if isSource {
// dst
offset = 0
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
exprs = append(exprs,
&expr.Range{
Op: expr.CmpOpEq,
Register: 1,
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {
for i, p := range port.Values {
if i > 0 {
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(p),
})
}
}
return exprs
}
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
// findSets scans an nftables rule's expressions for expr.Lookup and
// returns the named sets in occurrence order. Used at delete time to
// drop ipsetCounter references; peer and route ACLs go through it.
func findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName)
}
}
return sets
}

View File

@@ -1,196 +0,0 @@
//go:build !android
package nftables
import (
"encoding/binary"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
set: set,
prefixes: prefixes,
})
if err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return r.getIpSetExprs(ref, isSource)
}
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
prefixes := firewall.MergeIPRanges(input.prefixes)
nfset := &nftables.Set{
Name: setName,
Comment: input.set.Comment(),
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: r.af.setKeyType,
}
elements := r.convertPrefixesToSet(prefixes)
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
// The set is committed now. If a later batch fails, destroy it: the
// refcounter records nothing on a create-callback error, so it would
// otherwise leak, and a partial source set fails-open for deny rules.
if err := r.addRemainingElements(nfset, elements, maxElements); err != nil {
if derr := r.deleteIpSet(setName, nfset); derr != nil {
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
}
return nil, err
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}
// addRemainingElements adds element batches beyond the initial one in
// maxElements-sized chunks, flushing each. Called after the set has been
// created with its first batch.
func (r *family) addRemainingElements(nfset *nftables.Set, elements []nftables.SetElement, maxElements int) error {
nElements := len(elements)
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd := min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, nfset.Name, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
}
return nil
}
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range prefixes {
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
return elements
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
masked := prefix.Masked()
if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
if err != nil {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
// Overlapping prefixes (e.g. duplicate resolved addresses) make the
// interval set reject the batch, so merge them as createIpSet does.
prefixes = firewall.MergeIPRanges(prefixes)
elements := r.convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
offset := r.af.dstAddrOffset
if isSource {
// src offset
offset = r.af.srcAddrOffset
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: r.af.addrLen,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -0,0 +1,85 @@
package nftables
import (
"net"
)
type ipsetStore struct {
ipsetReference map[string]int
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsetReference: make(map[string]int),
ipsets: make(map[string]map[string]struct{}),
}
}
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
s.ipsetReference[ipsetName] = 0
ipList := make(map[string]struct{})
s.ipsets[ipsetName] = ipList
return ipList
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsetReference, ipsetName)
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
delete(ipList, ip.String())
}
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
ipList[ip.String()] = struct{}{}
}
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return false
}
_, ok = ipList[ip.String()]
return ok
}
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
s.ipsetReference[ipsetName]++
}
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
r, ok := s.ipsetReference[ipsetName]
if !ok {
return
}
if r == 0 {
return
}
s.ipsetReference[ipsetName]--
}
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
if _, ok := s.ipsetReference[ipsetName]; !ok {
return false
}
if s.ipsetReference[ipsetName] == 0 {
return false
}
return true
}

View File

@@ -3,6 +3,7 @@ package nftables
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"
@@ -44,17 +45,18 @@ type iFaceMapper interface {
Address() wgaddr.Address
}
// Manager of nftables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
rConn *nftables.Conn
wgIface iFaceMapper
family4 *family
// IPv6 counterpart, nil when no v6 overlay.
family6 *family
router *router
aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
@@ -73,9 +75,14 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error
m.family4, err = newFamily(workTable, wgIface, mtu)
m.router, err = newRouter(workTable, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create family: %w", err)
return nil, fmt.Errorf("create router: %w", err)
}
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -93,21 +100,26 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.family6, err = newFamily(workTable6, wgIface, mtu)
m.router6, err = newRouter(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 family: %w", err)
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.family6.ipFwdState = m.family4.ipFwdState
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.family6 != nil
return m.router6 != nil
}
func (m *Manager) initIPv6() error {
@@ -116,8 +128,12 @@ func (m *Manager) initIPv6() error {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.family6.init(workTable6); err != nil {
return fmt.Errorf("v6 family init: %w", err)
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
}
return nil
@@ -146,13 +162,13 @@ func (m *Manager) reconcileExternalChains() error {
defer m.mutex.Unlock()
var merr *multierror.Error
if m.family4 != nil {
if err := m.family4.acceptExternalChainsRules(); err != nil {
if m.router != nil {
if err := m.router.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
}
}
if m.hasIPv6() {
if err := m.family6.acceptExternalChainsRules(); err != nil {
if err := m.router6.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
}
}
@@ -171,8 +187,12 @@ func (m *Manager) initFirewall() (err error) {
}
}()
if err := m.family4.init(workTable); err != nil {
return fmt.Errorf("family init: %w", err)
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
return fmt.Errorf("acl manager init: %w", err)
}
if m.hasIPv6() {
@@ -200,7 +220,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.family4.mtu,
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -215,12 +235,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.family4.Reset(); err != nil {
log.Warnf("rollback family: %v", err)
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
}
if m.hasIPv6() {
if err := m.family6.Reset(); err != nil {
log.Warnf("rollback v6 family: %v", err)
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
@@ -231,77 +251,118 @@ func (m *Manager) rollbackInit() {
}
}
// AddFilterRule installs a packet-filtering rule.
// AddPeerFiltering rule to the firewall
//
// Destination semantics: zero Network → input chain (peer ACL);
// set Network → forward chain (route ACL).
//
// Sources are a single address family; the rule is dispatched to the
// matching per-family backend.
func (m *Manager) AddFilterRule(
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
fam := m.family4
if isIPv6Rule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
}
fam = m.family6
if ip.To4() != nil {
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
// DeleteFilterRule removes a filtering rule. The owning family is found
// by id, refreshing from the kernel if the in-memory caches miss so a
// stale cache cannot leak the rule. family.DeleteFilterRule is idempotent
// when the id is absent.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
fam, err := m.familyForRuleID(rule.ID(), (*family).hasRule)
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule)
}
func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
// router; the cached maps are normally authoritative, so the kernel is only
// consulted when neither map knows about the rule.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
id := rule.ID()
r, err := m.routerForRuleID(id, (*router).hasRule)
if err != nil {
return err
}
return fam.DeleteFilterRule(rule)
return r.DeleteRouteRule(rule)
}
// familyForRuleID picks the family holding the rule with the given id, using
// routerForRuleID picks the router holding the rule with the given id, using
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
// from the kernel once and re-checks before falling back to the v4 family.
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool) (*family, error) {
if has(m.family4, id) {
return m.family4, nil
// from the kernel once and re-checks before falling back to the v4 router.
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
if has(m.router, id) {
return m.router, nil
}
if m.hasIPv6() && has(m.family6, id) {
return m.family6, nil
if m.hasIPv6() && has(m.router6, id) {
return m.router6, nil
}
if !m.hasIPv6() {
return m.family4, nil
return m.router, nil
}
if err := m.family4.refreshRulesMap(); err != nil {
if err := m.router.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v4 rules: %w", err)
}
if err := m.family6.refreshRulesMap(); err != nil {
if err := m.router6.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v6 rules: %w", err)
}
if has(m.family6, id) && !has(m.family4, id) {
return m.family6, nil
if has(m.router6, id) && !has(m.router, id) {
return m.router6, nil
}
return m.family4, nil
return m.router, nil
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -320,10 +381,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddNatRule(pair)
return m.router6.AddNatRule(pair)
}
if err := m.family4.AddNatRule(pair); err != nil {
if err := m.router.AddNatRule(pair); err != nil {
return err
}
@@ -335,7 +396,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// so the eventual cleanup still works.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.AddNatRule(v6Pair); err != nil {
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -351,18 +412,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.family6.RemoveNatRule(pair)
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.family4.RemoveNatRule(pair); err != nil {
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -384,11 +445,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.family4.createDefaultAllowRules(); err != nil {
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
if m.hasIPv6() {
if err := m.family6.createDefaultAllowRules(); err != nil {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
@@ -405,11 +466,11 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.family6, isLegacy)
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
}
@@ -423,13 +484,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
}
if m.hasIPv6() {
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
}
}
@@ -469,14 +530,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -490,12 +551,12 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.family4.Flush(); err != nil {
if err := m.aclManager.Flush(); err != nil {
return err
}
if m.hasIPv6() {
if err := m.family6.Flush(); err != nil {
if err := m.aclManager6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
@@ -516,9 +577,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddDNATRule(rule)
return m.router6.AddDNATRule(rule)
}
return m.family4.AddDNATRule(rule)
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -526,7 +587,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule)
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
if err != nil {
return err
}
@@ -547,12 +608,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -569,9 +630,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -583,9 +644,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -597,9 +658,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -611,9 +672,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
@@ -842,14 +903,3 @@ func getEstablishedExprs(register uint32) []expr.Any {
},
}
}
// isIPv6Rule reports whether the rule belongs to the v6 table. For a
// prefix destination the destination family decides; otherwise the
// (single-family) sources do, since management duplicates rules per
// family.
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,5 +1,3 @@
//go:build integration && !android
package nftables
import (
@@ -72,13 +70,13 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules")
@@ -149,12 +147,15 @@ func TestNftablesManager(t *testing.T) {
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
for _, r := range rule {
err = manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
@@ -179,39 +180,47 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
testClient := &nftables.Conn{}
// Add accept rule first
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for the same traffic
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
t.Logf("Found %d rules in nftables chain", len(rules))
// Single-source rules emit a direct payload+cmp on the source IP
// (no set lookup). Match by source-IP + port + verdict instead of
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
// that this test predates.
wantSrc := ip.AsSlice()
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex = -1, -1
for i, rule := range rules {
var hasSrc, hasPort80 bool
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
hasPort80 := false
var action string
for _, e := range rule.Exprs {
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
if bytes.Equal(cmp.Data, wantSrc) {
hasSrc = true
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
switch lookup.SetName {
case "accept-http":
hasAcceptHTTPSet = true
case "deny-http":
hasDenyHTTPSet = true
}
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
hasPort80 = true
}
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
switch verdict.Kind {
case expr.VerdictAccept:
@@ -222,15 +231,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
}
if !hasSrc || !hasPort80 {
continue
}
switch action {
case "ACCEPT":
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
acceptRuleIndex = i
case "DROP":
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
denyRuleIndex = i
}
}
@@ -274,7 +279,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -356,10 +361,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
@@ -432,10 +437,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
@@ -545,7 +550,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
@@ -560,7 +565,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T) {
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
@@ -586,9 +591,9 @@ func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build integration && !android
//go:build !android
package nftables
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and family for all chains to be present
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
nftablesTestingClient := &nftables.Conn{}
rtr := manager.family4
rtr := manager.router
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
@@ -90,9 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
testRouter := &family{af: afIPv4}
sourceExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Source.Prefix, true)
destExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Destination.Prefix, false)
testRouter := &router{af: afIPv4}
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order
// nolint:gocritic
@@ -100,14 +100,14 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNameManglePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
@@ -135,19 +135,19 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.family4
rtr := manager.router
// First add the NAT rule using the family's method
// First add the NAT rule using the router's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
// Verify the rule was added
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
@@ -200,11 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *family) {
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
@@ -314,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
require.True(t, ok, "Rule not found in filters map")
rule := stored.nftRule
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
@@ -339,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule
for _, rule := range rules {
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
if string(rule.UserData) == ruleKey.ID() {
nftRule = rule
break
}
@@ -367,12 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset family")
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
@@ -509,42 +509,6 @@ func TestNftablesCreateIpSet(t *testing.T) {
}
}
// TestNftablesUpdateSetMergesOverlapping verifies that UpdateSet merges
// overlapping prefixes before adding them. An interval set rejects
// overlapping elements, so without the merge a batch holding a /32 already
// covered by a /24, or a duplicate address as DNS resolution can produce,
// would fail.
func TestNftablesUpdateSetMergesOverlapping(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "create work table")
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "create family")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "reset family")
}()
initial := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
set := firewall.NewPrefixSet(initial)
created, err := r.createIpSet(set.HashedName(), setInput{prefixes: initial})
require.NoError(t, err, "create ip set")
require.NotNil(t, created)
overlapping := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("192.168.1.1/32"),
}
require.NoError(t, r.UpdateSet(set, overlapping), "UpdateSet must merge overlapping prefixes")
}
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
@@ -554,11 +518,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset family")
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
@@ -897,13 +861,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddFilterRule(
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
@@ -914,11 +878,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteFilterRule(ruleKey))
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
@@ -938,55 +902,6 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
// rule is actually removed from the kernel on delete. The route chain is
// not refreshed by Flush, so the stored rule carries a zero handle;
// DeleteFilterRule must pull live handles itself before issuing the
// delete or the kernel rule leaks. Regression test for that path.
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
ruleKey, err := r.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
countKernelRules := func() int {
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err)
n := 0
for _, rule := range list {
if string(rule.UserData) == string(ruleKey.ID()) {
n++
}
}
return n
}
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
require.NoError(t, r.DeleteFilterRule(ruleKey))
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
@@ -996,28 +911,24 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := id.RuleID("stale-route-rule")
staleRule := &Rule{
nftRule: &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
},
id: staleKey,
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
r.filters[staleKey] = staleRule
// DeleteFilterRule should not return an error for stale handles
err = r.DeleteFilterRule(staleRule)
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
@@ -1039,7 +950,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
Masquerade: true,
}
rtr := manager.family4
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
@@ -1049,11 +960,11 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
})
// Corrupt the handle to simulate stale state
natRuleKey := pair.GenKey(firewall.PreroutingFormat)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat)
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
@@ -1068,7 +979,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
@@ -1099,7 +1010,7 @@ func TestCalculateLastIP(t *testing.T) {
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &family{af: afIPv6}
r := &router{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),

View File

@@ -1,490 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *family) rollbackRules(pair firewall.RouterPair) {
keys := []firewall.RuleID{
pair.GenKey(firewall.ForwardingFormat),
pair.GenKey(firewall.PreroutingFormat),
firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *family) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq
if pair.Inverse {
op = expr.CmpOpNeq
}
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
ruleID := pair.GenKey(firewall.PreroutingFormat)
if _, exists := r.rules[ruleID]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
}
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleID] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: exprs,
UserData: []byte(ruleID),
})
return nil
}
func (r *family) addPostroutingRules() {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
if r.mtu <= overhead {
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
return nil
}
mss := r.mtu - overhead
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return r.conn.Flush()
}
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
var exprs []expr.Any
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
exprs = append(exprs,
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
)
ruleID := pair.GenKey(firewall.ForwardingFormat)
if _, exists := r.rules[ruleID]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleID] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: exprs,
UserData: []byte(ruleID),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleID)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.PreroutingFormat)
rule, exists := r.rules[ruleID]
if !exists {
log.Debugf("prerouting rule %s not found", ruleID)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleID)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}

View File

@@ -1,26 +1,21 @@
package nftables
import (
"net/netip"
"net"
"github.com/google/nftables"
"github.com/netbirdio/netbird/client/firewall/manager"
)
// Rule wraps an installed filter rule (peer or route). Source set
// membership is encoded in the rule's expressions; DeleteFilterRule
// recovers the set name via findSets so the refcounter can drop the
// right reference. mangleRule is set only for peer rules.
// Rule to handle management of rules
type Rule struct {
nftRule *nftables.Rule
mangleRule *nftables.Rule
// sources is the canonical source list this rule was created for.
sources []netip.Prefix
id manager.RuleID
nftSet *nftables.Set
ruleID string
ip net.IP
}
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -1,27 +0,0 @@
//go:build integration && !android
package nftables
import (
"fmt"
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
}
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"os"
"slices"
@@ -71,14 +72,12 @@ const (
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// peerRules is the canonical list-based storage for peer ACL rules.
// Match order is significant: drop rules come before accept rules so
// callers should consult the slice in order.
type peerRules []*PeerRule
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule
type routeRules []*RouteRule
type RouteRules []*RouteRule
func (r routeRules) Sort() {
func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b *RouteRule) int {
// Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
@@ -87,44 +86,22 @@ func (r routeRules) Sort() {
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
return 1
}
return strings.Compare(string(a.id), string(b.id))
return strings.Compare(a.id, b.id)
})
}
// peerRuleSpec carries the parameters that define a peer filter rule,
// threaded together through the build path so the builders take a single
// argument instead of a long parameter list.
type peerRuleSpec struct {
mgmtID []byte
sources []netip.Prefix
ipLayer gopacket.LayerType
matchAny bool
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
action firewall.Action
}
// Manager userspace firewall manager
type Manager struct {
decoders sync.Pool
wgIface common.IFaceMapper
// nativeFirewall is the kernel firewall (nftables/iptables) used for
// the split case where peer ACLs run in userspace here but routing
// stays in the kernel: when the userspace firewall is forced yet the
// router keeps using the kernel, route NAT/ACLs and DNAT are
// delegated to it. It is nil when no native backend is available.
nativeFirewall firewall.Manager
mutex sync.RWMutex
outgoingRules map[netip.Addr]RuleSet
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
incomingDenyRules peerRules
incomingAcceptRules peerRules
incomingDenyIndex peerRuleIndex
incomingAcceptIndex peerRuleIndex
peerRulesMap map[nbid.RuleID]*PeerRule
routeRules routeRules
routeRulesMap map[nbid.RuleID]*RouteRule
mutex sync.RWMutex
// indicates whether server routes are disabled
disableServerRoutes bool
@@ -206,6 +183,24 @@ func (d *decoder) decodePacket(data []byte) error {
}
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
return mgr, nil
}
func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error
@@ -236,7 +231,7 @@ func parseCreateEnv() (bool, bool, bool) {
return disableConntrack, enableLocalForwarding, disableMSSClamping
}
func Create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{
@@ -259,8 +254,11 @@ func Create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return d
},
},
wgIface: iface,
nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet),
incomingDenyRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
stateful: !disableConntrack,
@@ -268,7 +266,6 @@ func Create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
@@ -323,7 +320,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
}
var rules []firewall.Rule
v4Rule, err := m.addRouteRule(
v4Rule, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: wgPrefix},
@@ -339,7 +336,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
if v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
v6Rule, err := m.addRouteRule(
v6Rule, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: v6Net},
@@ -491,136 +488,64 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
// addPeerRule installs an input-chain rule that matches packets
// by source only. Called from AddFilterRule when the caller doesn't
// specify a destination. Mixed-family inputs are split: each family
// gets its own rule with a family-correct ipLayer so packet decoding
// matches what the matcher expects.
func (m *Manager) addPeerRule(
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
sources []netip.Prefix,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
_ string,
) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{
id: uuid.New().String(),
mgmtId: id,
ip: i,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
}
if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
r.matchByIP = false
}
r.sPort = sPort
r.dPort = dPort
r.protoLayer = protoToLayer(proto, r.ipLayer)
m.mutex.Lock()
defer m.mutex.Unlock()
if sourcesMatchAny(sources) {
spec := peerRuleSpec{
mgmtID: id,
sources: sources,
ipLayer: layerTypeAll,
matchAny: true,
proto: proto,
sPort: sPort,
dPort: dPort,
action: action,
}
return m.addOnePeerRule(spec), nil
}
// Sources are a single family; normalize v4-mapped prefixes to plain
// v4 and pick the matching IP layer.
normalized := make([]netip.Prefix, len(sources))
ipLayer := layers.LayerTypeIPv4
for i, p := range sources {
normalized[i] = firewall.UnmapPrefix(p)
if normalized[i].Addr().Is6() {
ipLayer = layers.LayerTypeIPv6
}
}
spec := peerRuleSpec{
mgmtID: id,
sources: normalized,
ipLayer: ipLayer,
matchAny: false,
proto: proto,
sPort: sPort,
dPort: dPort,
action: action,
}
return m.addOnePeerRule(spec), nil
}
// addOnePeerRule builds and registers a single-family peer rule, or
// returns the existing rule when one with the same content key is
// already installed. The caller must hold m.mutex. The content key is
// the shared GenerateRuleID with an empty destination, so peer
// rules dedup the same way route rules and the kernel backends do.
//
// There is no refcount: a content key is installed once and deleted on
// the first DeleteFilterRule for that key. The caller must therefore
// key its own tracking by the returned rule id so add and delete stay
// balanced per content key; the acl manager does this via
// peerRulesPairs. The content key is order-independent, so callers
// passing the same sources in any order dedup to one rule.
func (m *Manager) addOnePeerRule(spec peerRuleSpec) *PeerRule {
ruleID := nbid.GenerateRuleID(spec.sources, firewall.Network{}, spec.proto, spec.sPort, spec.dPort, spec.action)
if existing, ok := m.peerRulesMap[ruleID]; ok {
return existing
}
rule := m.buildPeerRule(ruleID, spec)
m.registerPeerRule(rule)
return rule
}
func (m *Manager) buildPeerRule(ruleID nbid.RuleID, spec peerRuleSpec) *PeerRule {
r := &PeerRule{
id: ruleID,
mgmtId: spec.mgmtID,
sources: spec.sources,
matchAny: spec.matchAny,
action: spec.action,
srcPort: spec.sPort,
dstPort: spec.dPort,
}
if !spec.matchAny {
r.sourceAddrs = make(map[netip.Addr]struct{}, len(spec.sources))
for _, p := range spec.sources {
if p.Bits() == p.Addr().BitLen() {
r.sourceAddrs[p.Addr()] = struct{}{}
}
}
}
r.protoLayer = protoToLayer(spec.proto, spec.ipLayer)
return r
}
// registerPeerRule records a freshly built peer rule in the matching
// slice, index, and dedup map. The caller must hold m.mutex.
func (m *Manager) registerPeerRule(r *PeerRule) {
if r.action == firewall.ActionDrop {
m.incomingDenyRules = append(m.incomingDenyRules, r)
m.incomingDenyIndex.add(r)
var targetMap map[netip.Addr]RuleSet
if r.drop {
targetMap = m.incomingDenyRules
} else {
m.incomingAcceptRules = append(m.incomingAcceptRules, r)
m.incomingAcceptIndex.add(r)
targetMap = m.incomingRules
}
m.peerRulesMap[r.id] = r
if _, ok := targetMap[r.ip]; !ok {
targetMap[r.ip] = make(RuleSet)
}
targetMap[r.ip][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
// sourcesMatchAny reports whether the source list matches every source,
// i.e. contains an explicit /0 prefix. An empty list does not qualify:
// AddFilterRule rejects it with ErrNoSources, so "match any" is always
// the deliberate /0 case.
func sourcesMatchAny(sources []netip.Prefix) bool {
for _, p := range sources {
if p.Bits() == 0 {
return true
}
}
return false
}
// AddFilterRule is the unified entry point for both peer (input chain)
// and route (forward chain) filtering rules. The destination
// distinguishes the two semantics: a zero Network installs an
// input-side rule that matches by source only; a set Network installs
// a forward-side rule that also matches the destination.
func (m *Manager) AddFilterRule(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
@@ -628,37 +553,13 @@ func (m *Manager) AddFilterRule(
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
if destination.IsZero() {
return m.addPeerRule(id, sources, proto, sPort, dPort, action)
}
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteRule(id, sources, destination, proto, sPort, dPort, action)
}
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
// is used to route to the correct internal path.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if r, ok := rule.(*PeerRule); ok {
return m.deletePeerRuleLocked(r)
}
// Either our *RouteRule or, under native delegation, a native
// route-rule object that implements firewall.Rule but isn't one of
// our concrete types. The route path forwards the latter to the
// native firewall.
return m.deleteRouteRule(rule)
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) addRouteRule(
func (m *Manager) addRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
@@ -667,17 +568,18 @@ func (m *Manager) addRouteRule(
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleID]; ok {
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{
id: ruleID,
// TODO: consolidate these IDs
id: string(ruleKey),
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -692,55 +594,70 @@ func (m *Manager) addRouteRule(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleID] = &rule
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteFilterRule(rule)
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleID := rule.ID()
trimmed, _, ok := removeRuleByID(m.routeRules, ruleID)
if !ok {
return fmt.Errorf("route rule not found: %s", ruleID)
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
m.routeRules = trimmed
delete(m.routeRulesMap, ruleID)
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
// deletePeerRuleLocked removes a peer rule from the matching slice,
// index, and dedup map. The caller must hold m.mutex.
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
target, index := &m.incomingAcceptRules, &m.incomingAcceptIndex
if r.action == firewall.ActionDrop {
target, index = &m.incomingDenyRules, &m.incomingDenyIndex
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, ok := rule.(*PeerRule)
if !ok {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
trimmed, stored, ok := removeRuleByID(*target, r.id)
if !ok {
var sourceMap map[netip.Addr]RuleSet
if r.drop {
sourceMap = m.incomingDenyRules
} else {
sourceMap = m.incomingRules
}
if ruleset, ok := sourceMap[r.ip]; ok {
if _, exists := ruleset[r.id]; !exists {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(ruleset, r.id)
if len(ruleset) == 0 {
delete(sourceMap, r.ip)
}
} else {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
*target = trimmed
index.remove(stored)
delete(m.peerRulesMap, r.id)
return nil
}
// removeRuleByID removes the first rule whose id matches ruleID from
// rules, preserving order. It returns the trimmed slice, the removed
// rule, and whether a match was found.
func removeRuleByID[S ~[]T, T firewall.Rule](rules S, ruleID firewall.RuleID) (S, T, bool) {
idx := slices.IndexFunc(rules, func(r T) bool { return r.ID() == ruleID })
var removed T
if idx < 0 {
return rules, removed, false
}
removed = rules[idx]
return slices.Delete(rules, idx, idx+1), removed, true
return nil
}
// SetLegacyManagement doesn't need to be implemented for this manager
@@ -757,11 +674,9 @@ func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
m.incomingDenyRules = m.incomingDenyRules[:0]
m.incomingAcceptRules = m.incomingAcceptRules[:0]
m.incomingDenyIndex.reset()
m.incomingAcceptIndex.reset()
clear(m.peerRulesMap)
clear(m.outgoingRules)
clear(m.incomingDenyRules)
clear(m.incomingRules)
clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
@@ -905,11 +820,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
case layers.LayerTypeIPv4:
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src.Unmap(), dst.Unmap()
return src, dst
case layers.LayerTypeIPv6:
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src.Unmap(), dst.Unmap()
return src, dst
default:
return netip.Addr{}, netip.Addr{}
}
@@ -1489,12 +1404,20 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
return nil, false
}
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
return nil, true
}
@@ -1515,6 +1438,39 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false
}
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue
}
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, rule.drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
continue
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, rule.drop, true
}
}
return nil, false, false
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()

View File

@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -114,13 +114,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddFilterRule(
_, err := m.AddPeerFiltering(
nil,
pfx(ip), fw.Network{},
ip,
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept)
fw.ActionAccept,
"",
)
require.NoError(b, err)
}
},
@@ -131,13 +133,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddFilterRule(
_, err := m.AddPeerFiltering(
nil,
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop)
fw.ActionDrop,
"",
)
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -166,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -206,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -249,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -407,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -534,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -542,7 +546,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -617,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -625,7 +629,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -728,14 +732,14 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -808,13 +812,13 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
if sc.rules {
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -927,7 +931,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, r := range rules {
dst := fw.Network{Prefix: r.dest}
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}
@@ -1012,7 +1016,7 @@ func BenchmarkMSSClamping(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -1077,7 +1081,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -1132,7 +1136,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -32,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NotNil(t, manager)
@@ -496,32 +496,40 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule for the same IP to test drop rule precedence
rules, err := manager.AddFilterRule(
rules, err := manager.AddPeerFiltering(
nil,
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
net.ParseIP(tc.ruleIP),
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept)
fw.ActionAccept,
"",
)
require.NoError(t, err)
require.NotNil(t, rules)
require.NotEmpty(t, rules)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddFilterRule(
rules, err := manager.AddPeerFiltering(
nil,
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
net.ParseIP(tc.ruleIP),
tc.ruleProto,
tc.ruleSrcPort,
tc.ruleDstPort,
tc.ruleAction)
tc.ruleAction,
"",
)
require.NoError(t, err)
require.NotNil(t, rules)
require.NotEmpty(t, rules)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
@@ -549,7 +557,7 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -664,18 +672,22 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
require.NoError(t, err)
require.NotNil(t, rules)
require.NotEmpty(t, rules)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
@@ -788,7 +800,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager)
@@ -1393,7 +1405,7 @@ func TestRouteACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
@@ -1403,13 +1415,13 @@ func TestRouteACLFiltering(t *testing.T) {
fw.ActionAccept,
)
require.NoError(t, err)
require.NotEmpty(t, rule)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rule))
require.NoError(t, manager.DeleteRouteRule(rule))
})
}
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources,
tc.rule.dest,
@@ -1419,10 +1431,10 @@ func TestRouteACLFiltering(t *testing.T) {
tc.rule.action,
)
require.NoError(t, err)
require.NotEmpty(t, rule)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rule))
require.NoError(t, manager.DeleteRouteRule(rule))
})
srcIP := netip.MustParseAddr(tc.srcIP)
@@ -1590,9 +1602,9 @@ func TestRouteACLOrder(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var addedRules []fw.Rule
var rules []fw.Rule
for _, r := range tc.rules {
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
r.sources,
r.dest,
@@ -1603,12 +1615,12 @@ func TestRouteACLOrder(t *testing.T) {
)
require.NoError(t, err)
require.NotNil(t, rule)
addedRules = append(addedRules, rule)
rules = append(rules, rule)
}
t.Cleanup(func() {
for _, rule := range addedRules {
require.NoError(t, manager.DeleteFilterRule(rule))
for _, rule := range rules {
require.NoError(t, manager.DeleteRouteRule(rule))
}
})
@@ -1634,7 +1646,7 @@ func TestRouteACLSet(t *testing.T) {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -1643,7 +1655,7 @@ func TestRouteACLSet(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -1677,7 +1689,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddFilterRule(
_, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
@@ -1688,7 +1700,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
)
require.NoError(t, err)
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},

View File

@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -147,7 +147,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteFilterRule(rule1)
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
@@ -156,59 +156,6 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedDualStack verifies that when the interface has an
// IPv6 overlay address, blockInvalidRouted installs a block rule for both
// the v4 and v6 WG prefixes and that routed traffic to the v6 prefix is
// denied. The v4-only soft-skip path is covered by
// TestBlockInvalidRoutedIdempotent, whose mock leaves IPv6Net invalid.
func TestBlockInvalidRoutedDualStack(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
wgNet6 := netip.MustParsePrefix("fd00:1234::1/64")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
IPv6: wgNet6.Addr(),
IPv6Net: wgNet6,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
rules, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.Len(t, rules, 2, "dual-stack interface must produce a v4 and a v6 block rule")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount, "should have one block rule per family")
// v6 routed traffic to the WG prefix must be denied.
srcIP := netip.MustParseAddr("2001:db8::1")
dstIP := netip.MustParseAddr("fd00:1234::50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "block rule should deny routed traffic to the v6 WG prefix")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
@@ -235,7 +182,7 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -298,7 +245,7 @@ func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -327,7 +274,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -357,7 +304,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -368,7 +315,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
)
require.NoError(t, err)
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -382,7 +329,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteFilterRule(rule1)
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
@@ -417,7 +364,7 @@ func setupTestManager(t *testing.T) *Manager {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())

View File

@@ -78,7 +78,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -89,7 +89,7 @@ func TestManagerCreate(t *testing.T) {
}
}
func TestManagerAddFilterRule(t *testing.T) {
func TestManagerAddPeerFiltering(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error {
@@ -98,7 +98,7 @@ func TestManagerAddFilterRule(t *testing.T) {
},
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -109,7 +109,7 @@ func TestManagerAddFilterRule(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -131,7 +131,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -142,33 +142,63 @@ func TestManagerDeleteRule(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
peerRule, ok := rule2.(*PeerRule)
require.True(t, ok, "rule should be a PeerRule")
inMap := func() bool {
if peerRule.action == fw.ActionDrop {
return findRuleByID(m.incomingDenyRules, ip, rule2.ID())
// Check rules exist in appropriate maps
for _, r := range rule2 {
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule exists in deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if !found {
t.Errorf("rule2 is not in the expected rules map")
}
return findRuleByID(m.incomingAcceptRules, ip, rule2.ID())
}
require.True(t, inMap(), "rule2 should be in the expected rules list")
for _, r := range rule2 {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
require.NoError(t, m.DeleteFilterRule(rule2), "failed to delete rule")
require.False(t, inMap(), "rule2 should be removed from the rules list")
// Check rules are removed from appropriate maps
for _, r := range rule2 {
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule is removed from deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if found {
t.Errorf("rule2 should be removed from the rules map")
}
}
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, nbiface.DefaultMTU)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -192,7 +222,7 @@ func TestSetUDPPacketHook(t *testing.T) {
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, nbiface.DefaultMTU)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -220,7 +250,7 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -230,34 +260,36 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeleteFilterRule(rule1)
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeleteFilterRule(rule2)
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules should be cleaned up when empty")
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
@@ -267,7 +299,7 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -279,21 +311,27 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
require.NoError(t, m.DeleteFilterRule(rules))
require.NoError(t, m.DeleteFilterRule(allowRules))
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
@@ -307,7 +345,7 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -316,39 +354,41 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeleteFilterRule(allowRule)
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeleteFilterRule(denyRule)
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
allowExists := countRulesForAddr(m.incomingAcceptRules, addr) > 0
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
@@ -360,7 +400,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -371,7 +411,7 @@ func TestManagerReset(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -383,7 +423,7 @@ func TestManagerReset(t *testing.T) {
return
}
if len(m.incomingAcceptRules) != 0 || len(m.incomingDenyRules) != 0 {
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
t.Errorf("rules are not empty")
}
}
@@ -399,7 +439,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -409,7 +449,7 @@ func TestNotMatchByIP(t *testing.T) {
proto := fw.ProtocolUDP
action := fw.ActionAccept
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -462,7 +502,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, nil, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@@ -481,7 +521,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, nbiface.DefaultMTU)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close()
@@ -566,7 +606,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -581,7 +621,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}
@@ -593,7 +633,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, nbiface.DefaultMTU)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker
@@ -805,7 +845,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -818,7 +858,7 @@ func TestUpdateSetMerge(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -891,7 +931,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -899,7 +939,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -1011,7 +1051,7 @@ func TestMSSClamping(t *testing.T) {
},
}
manager, err := Create(ifaceMock, nil, false, flowLogger, 1280)
manager, err := Create(ifaceMock, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -1203,7 +1243,7 @@ func TestShouldForward(t *testing.T) {
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
}
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -1318,7 +1358,7 @@ func TestShouldForward(t *testing.T) {
// Re-create manager to pick up the new address with IPv6
require.NoError(t, manager.Close(nil))
manager, err = Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
v6Cases := []struct {

View File

@@ -66,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -126,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -198,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -310,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -483,7 +483,7 @@ func BenchmarkPortDNAT(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -15,7 +15,7 @@ import (
func TestPortDNATBasic(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -51,7 +51,7 @@ func TestPortDNATBasic(t *testing.T) {
func TestPortDNATMultipleRules(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -17,7 +17,7 @@ import (
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -106,7 +106,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -154,7 +154,7 @@ func TestDNATMappingManagement(t *testing.T) {
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -204,7 +204,7 @@ func TestInboundPortDNAT(t *testing.T) {
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -1,327 +0,0 @@
//go:build uspbench
package uspfilter
import (
"fmt"
"io"
"math/rand"
"net"
"net/netip"
"runtime"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
// rules, each with K source peers in its set.
//
// With the reverse-source index, miss cost is independent of M and
// hit cost grows only with the number of rules touching a single
// srcIP, not with total rule count.
func BenchmarkPeerACLMatch(b *testing.B) {
shapes := []struct{ M, K int }{
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
}
families := []struct {
name string
v6 bool
}{{"v4", false}, {"v6", true}}
for _, fam := range families {
for _, s := range shapes {
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, true, fam.v6)
})
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, false, fam.v6)
})
}
}
}
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
// Miss packets are dropped, so they always traverse the full peer
// ACL matcher (every bucket) without short-circuiting and without
// touching conntrack. Disable conntrack for the miss case so it
// measures the matcher, not established-state lookups. The hit case
// keeps conntrack on: an accepted packet reaches trackInbound, which
// needs the trackers conntrack creates.
if !hit {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
}
bits := 32
genPkt := generatePacket
addrs := uniqueAddrs
if v6 {
bits = 128
genPkt = generatePacket6
addrs = uniqueAddrs6
}
// dstIP must be a local IP so filterInbound takes the local-traffic
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
// address the manager doesn't own would be treated as routed and
// short-circuit before the peer matcher.
dstIP := addrs(1, 2)[0]
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
if v6 {
// The local-IP manager needs a valid v4 address too; expose the v6
// dst as the interface's IPv6 so IsLocalIP recognizes it.
mockAddr = wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: dstIP,
IPv6Net: netip.PrefixFrom(dstIP, bits),
}
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { return mockAddr },
}, nil, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
// Generate M policies × K source peers, all distinct.
all := addrs(m*k, 1)
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, k)
for j, a := range all[i*k : (i+1)*k] {
sources[j] = netip.PrefixFrom(a, bits)
}
_, err := manager.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
// Hit: cycle through real sources, picking the matching policy's port.
// Miss: a source from a disjoint range, port 80 (matches no policy).
var pktFn func(i int) []byte
if hit {
pktFn = func(i int) []byte {
policy := i % m
src := all[policy*k+(i%k)]
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
}
} else {
miss := addrs(4096, 99)
pktFn = func(i int) []byte {
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
}
}
// Pre-build a pool to avoid allocations dominating the measurement.
pool := make([][]byte, 1024)
for i := range pool {
pool[i] = pktFn(i)
}
// Confirm the matcher is actually exercised: a hit packet must be
// allowed and a miss packet dropped. Without this the benchmark
// could silently time the routed early-return instead.
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
"benchmark must reach the peer ACL matcher")
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(pool[i%len(pool)], 0)
}
}
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
// the source-keyed index across realistic deployment shapes. Two
// dimensions matter: (M, K), the number of policies × peers-per-policy,
// and overlap, the fraction of peers shared between policies.
//
// The output uses ReportMetric("bytes/rule") so the cost can be
// compared across shapes directly. Total bytes = bytes/rule * M.
func BenchmarkPeerACLIndexMemory(b *testing.B) {
cases := []struct {
name string
M, K int
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
}{
{"M=10/K=100/disjoint", 10, 100, 0},
{"M=100/K=100/disjoint", 100, 100, 0},
{"M=100/K=1000/disjoint", 100, 1000, 0},
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
}
for _, c := range cases {
b.Run(c.name, func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mgr, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, nil, false, flowLogger, iface.DefaultMTU)
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
runtime.GC()
var ms runtime.MemStats
runtime.ReadMemStats(&ms)
before := ms.HeapAlloc
// Drop the manager's external roots so we can isolate
// the index cost. We hold the manager itself live; the
// index is what we measure on the second pass.
mgr.incomingAcceptIndex.reset()
mgr.incomingDenyIndex.reset()
mgr.incomingAcceptRules = mgr.incomingAcceptRules[:0]
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
runtime.GC()
runtime.ReadMemStats(&ms)
after := ms.HeapAlloc
delta := int64(before) - int64(after)
if delta < 0 {
delta = 0
}
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
b.ReportMetric(float64(delta), "bytes/total")
require.NoError(b, mgr.Close(nil))
}
})
}
}
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
b.Helper()
pool := uniqueAddrs(k+m*k, 1) // big enough universe
sharedLen := int(float64(k) * overlapFrac)
if sharedLen > k {
sharedLen = k
}
shared := pool[:sharedLen]
uniquePool := pool[sharedLen:]
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, 0, k)
for _, a := range shared {
sources = append(sources, netip.PrefixFrom(a, 32))
}
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
for _, a := range unique {
sources = append(sources, netip.PrefixFrom(a, 32))
}
_, err := mgr.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
}
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
// policy sources / dst; seed 99 puts misses in 10/8.
func uniqueAddrs(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [4]byte
if miss {
b[0] = 10
b[1] = byte(r.Intn(256))
} else {
b[0] = 100
b[1] = byte(64 + r.Intn(63))
}
b[2] = byte(r.Intn(256))
b[3] = byte(1 + r.Intn(254))
a := netip.AddrFrom4(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
// disjoint from any source.
func uniqueAddrs6(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [16]byte
if miss {
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
} else {
b[0] = 0xfd
}
for x := 8; x < 16; x++ {
b[x] = byte(r.Intn(256))
}
a := netip.AddrFrom16(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
// generatePacket for the v4 case.
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
b.Helper()
ipv6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: protocol,
SrcIP: srcIP,
DstIP: dstIP,
}
var transportLayer gopacket.SerializableLayer
switch protocol {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
transportLayer = udp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
return buf.Bytes()
}

View File

@@ -1,149 +0,0 @@
package uspfilter
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
func newTestManager(t *testing.T) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err, "create manager")
return m
}
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
// the same peer rule twice does not create two backing rules. The acl
// manager keys its own cache, but the firewall backend must be
// idempotent on its own so a double-apply cannot leak rules, matching
// the route path and the kernel backends.
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
assert.Equal(t, first.ID(), second.ID(), "duplicate add should return the same rule id")
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
}
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
// backend's no-refcount contract: a content key installed twice is
// still one rule, and the first DeleteFilterRule removes it. The
// backend does not refcount, so balance is the caller's job (it keys
// its tracking by the returned id and deletes once per key). If this
// ever silently grew a refcount, the acl manager's delete accounting
// would diverge from the kernel.
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
require.Equal(t, first.ID(), second.ID(), "dedup to one rule")
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
require.NoError(t, m.DeleteFilterRule(first), "delete once")
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
assert.NotContains(t, m.peerRulesMap, first.ID(), "dedup map entry cleared")
}
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
// content hash, not a random UUID: identical inputs produce the same id
// across independent managers. A random id breaks caller-side dedup.
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
ip := net.ParseIP("10.0.0.5")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
m1 := newTestManager(t)
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
m2 := newTestManager(t)
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
assert.Equal(t, r1.ID(), r2.ID(), "same inputs must produce the same rule id")
}
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
// differing only by port are kept separate.
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
action := fw.ActionAccept
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
require.NoError(t, err)
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
require.NoError(t, err)
assert.NotEqual(t, r80.ID(), r443.ID(), "different ports must produce different rule ids")
assert.Len(t, m.incomingAcceptRules, 2, "distinct rules must both be stored")
}
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
// matching on source port and one matching on destination port for the
// same selector do not collide: the port lands in a different slot, so
// the content key must differ.
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
require.NoError(t, err)
assert.NotEqual(t, dPortRule.ID(), sPortRule.ID(), "source-port and dest-port matches must produce different rule ids")
}
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
// list is rejected rather than treated as "match any". "Match any" must
// be an explicit /0, so a zeroed list can never silently widen a rule to
// every source.
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
m := newTestManager(t)
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
assert.Empty(t, m.incomingAcceptRules, "no rule should be stored for empty sources")
}

View File

@@ -1,105 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func newV6TestManager(t *testing.T, localV6 string) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IPv6: netip.MustParseAddr(localV6),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err, "create manager")
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
return m
}
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
SrcIP: net.ParseIP(src),
DstIP: net.ParseIP(dst),
}
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
return buf.Bytes()
}
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
// rules: a matching v6 source is accepted, a non-matching one is
// denied by the default. This is the end-to-end proof that the index
// is not v4-only.
func TestPeerACL_IPv6HostRule(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
src := net.ParseIP("fd00::1")
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err, "add v6 accept rule")
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
"v6 packet from the allowed /128 source must be accepted")
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
"v6 packet from an unlisted source must be denied by default")
}
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
// right index bucket: a /128 in bySource keyed by its address, and
// coarser prefixes (including ::/0) in the nonHost slice.
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
port := &fw.Port{Values: []uint16{53}}
host := netip.MustParseAddr("fd00::1")
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.nonHost, 1, "coarser v6 prefix must land in nonHost")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.nonHost, 2, "::/0 source must also land in nonHost")
}
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
// source prefix is normalized to v4 so a plain v4 packet matches it.
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
m := newTestManager(t)
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err)
v4 := netip.MustParseAddr("192.168.1.1")
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
}

View File

@@ -1,139 +0,0 @@
package uspfilter
import (
"net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// peerRuleIndex is the source-side dispatcher consulted on the packet
// hot path. It splits rules into two buckets by the shape of their
// source list:
//
// - bySource: every source is a host prefix (/32 for v4, /128 for
// v6). Keyed by the concrete source address, so a hit guarantees
// the source filter passes and the matcher goes straight to
// proto/port checks. This is the common case for peer ACLs.
// - nonHost: any source list with a prefix coarser than a host,
// including a /0 "match any". Walked linearly with a per-rule
// Contains() check. Expected small or empty for typical peer ACLs.
//
// Maintained incrementally by add/remove, never rebuilt.
type peerRuleIndex struct {
bySource map[netip.Addr][]*PeerRule
nonHost []*PeerRule
}
func (i *peerRuleIndex) add(r *PeerRule) {
if hasNonHostSource(r) {
i.nonHost = append(i.nonHost, r)
return
}
if i.bySource == nil {
i.bySource = make(map[netip.Addr][]*PeerRule)
}
for a := range r.sourceAddrs {
i.bySource[a] = append(i.bySource[a], r)
}
}
func (i *peerRuleIndex) remove(r *PeerRule) {
if hasNonHostSource(r) {
i.nonHost = slices.DeleteFunc(i.nonHost, eqRule(r))
return
}
if i.bySource == nil {
return
}
for a := range r.sourceAddrs {
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
if len(entries) == 0 {
delete(i.bySource, a)
} else {
i.bySource[a] = entries
}
}
}
func (i *peerRuleIndex) reset() {
i.bySource = nil
i.nonHost = i.nonHost[:0]
}
// match returns the first rule matching src and the decoded packet.
// Host rules are found by direct map lookup; nonHost rules need a
// per-rule source Contains() check, except match-any (/0) rules which
// apply to every source regardless of family (a v4 /0 also matches v6).
// Within either bucket the matcher runs the proto/port filter.
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range i.bySource[src] {
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
for _, rule := range i.nonHost {
if !rule.matchAny && !prefixesContain(rule.sources, src) {
continue
}
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
return nil, false, false
}
func eqRule(target *PeerRule) func(*PeerRule) bool {
return func(p *PeerRule) bool { return p == target }
}
// hasNonHostSource reports whether the rule has any source prefix
// that is not a single host address. Called only at add/remove time,
// not on the packet path.
func hasNonHostSource(r *PeerRule) bool {
for _, p := range r.sources {
if p.Bits() != p.Addr().BitLen() {
return true
}
}
return false
}
// matchProto applies the proto/port half of a rule against the
// decoded packet. Source matching is the caller's responsibility.
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
drop := rule.action == firewall.ActionDrop
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
return nil, false, false
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, drop, true
}
return nil, false, false
}
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
for _, p := range sources {
if p.Contains(src) {
return true
}
}
return false
}

View File

@@ -10,49 +10,24 @@ import (
// PeerRule to handle management of rules
type PeerRule struct {
id firewall.RuleID
mgmtId []byte
// sources is the canonical list of source prefixes this rule
// matches against.
sources []netip.Prefix
// sourceAddrs is a fast-path membership set for host-prefix
// sources (/32 v4, /128 v6). Populated alongside sources;
// consulted before falling back to prefix scan.
sourceAddrs map[netip.Addr]struct{}
// matchAny is true when sources covers everything (0.0.0.0/0,
// ::/0). In that case neither sourceAddrs nor sources need to be
// consulted.
matchAny bool
id string
mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// matchesSource reports whether the given source address is covered
// by this rule's source list.
func (r *PeerRule) matchesSource(src netip.Addr) bool {
if r.matchAny {
return true
}
if _, ok := r.sourceAddrs[src]; ok {
return true
}
for _, p := range r.sources {
if p.Contains(src) {
return true
}
}
return false
sPort *firewall.Port
dPort *firewall.Port
drop bool
}
// ID returns the rule id
func (r *PeerRule) ID() firewall.RuleID {
func (r *PeerRule) ID() string {
return r.id
}
type RouteRule struct {
id firewall.RuleID
id string
mgmtId []byte
sources []netip.Prefix
dstSet firewall.Set
@@ -64,6 +39,6 @@ type RouteRule struct {
}
// ID returns the rule id
func (r *RouteRule) ID() firewall.RuleID {
func (r *RouteRule) ID() string {
return r.id
}

View File

@@ -1,50 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// countRulesForAddr reports how many rules in the given slice match
// the supplied source address.
func countRulesForAddr(rules peerRules, src netip.Addr) int {
n := 0
for _, r := range rules {
if r.matchesSource(src) {
n++
}
}
return n
}
// findRuleByID returns true if the rules slice contains a rule with
// the given id whose source set covers src.
func findRuleByID(rules peerRules, src netip.Addr, id firewall.RuleID) bool {
for _, r := range rules {
if r.id == id && r.matchesSource(src) {
return true
}
}
return false
}
// pfx converts a single net.IP into the []netip.Prefix form
// AddFilterRule expects. A nil or unspecified address becomes a /0
// ("match any") prefix in the matching family; any other address
// becomes its /32 (or /128) host prefix.
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, _ := netip.AddrFromSlice(ip)
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -285,14 +285,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace.SourceIP = srcIP
trace.DestinationIP = dstIP
// A fragment or otherwise truncated packet has no transport layer.
// The inbound datapath drops these via isValidPacket; the tracer must
// guard explicitly since every downstream stage reads d.decoded[1].
if len(d.decoded) < 2 {
trace.AddResult(StageReceived, "Packet has no transport layer (fragment or unsupported protocol)", false)
return trace
}
// Determine protocol and ports
switch d.decoded[1] {
case layers.LayerTypeTCP:

View File

@@ -45,7 +45,7 @@ func TestTracePacket(t *testing.T) {
},
}
m, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
if !statefulMode {
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {

View File

@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
; Remove autostart entries from every view a previous installer may have used.
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -1,190 +0,0 @@
package acl
import (
"net/netip"
"sync"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
// invariant: the backends classify a rule as a peer rule purely by the
// absence of a destination (neither prefix nor set). A default route
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
// a route, not collapse into the peer path.
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
}
// A zero-value Network is the only peer-rule shape.
var empty fwmgr.Network
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
assert.False(t, empty.IsSet(), "zero Network must not be a set")
}
// TestDetermineDestinationAlwaysRoute verifies determineDestination
// never yields an empty Network for a valid route rule: every branch
// (static prefix, default route, dynamic with/without domains, with and
// without a local resolver) produces a destination that classifies as a
// route. If this regresses, a route rule would be dispatched down the
// peer path, which matches on source only.
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
cases := []struct {
name string
rule *mgmProto.RouteFirewallRule
resolver bool
sources []netip.Prefix
}{
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
require.NoError(t, err)
assert.True(t, dest.IsPrefix() || dest.IsSet(),
"destination must classify as a route, got empty Network")
})
}
}
// countingFirewall wraps a real firewall.Manager and counts filter-rule
// add/delete calls so a test can assert how many backing rules the acl
// manager actually creates and tears down.
type countingFirewall struct {
fwmgr.Manager
mu sync.Mutex
addCalls int
dels int
ruleIDs map[fwmgr.RuleID]struct{}
}
// distinctRules returns the number of distinct backing rules the
// backend produced. Because the backend dedups identical content,
// repeated AddFilterRule calls for the same rule resolve to one id.
func (f *countingFirewall) distinctRules() int {
f.mu.Lock()
defer f.mu.Unlock()
return len(f.ruleIDs)
}
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
rule, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if err == nil {
f.mu.Lock()
f.addCalls++
if f.ruleIDs == nil {
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
}
if rule != nil {
f.ruleIDs[rule.ID()] = struct{}{}
}
f.mu.Unlock()
}
return rule, err
}
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
err := f.Manager.DeleteFilterRule(r)
if err == nil {
f.mu.Lock()
f.dels++
delete(f.ruleIDs, r.ID())
f.mu.Unlock()
}
return err
}
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
t.Helper()
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
fw := &countingFirewall{Manager: realFW}
cleanup := func() {
require.NoError(t, realFW.Close(nil))
ctrl.Finish()
}
return NewDefaultManager(fw), fw, cleanup
}
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
// the backends rely on: two policies that authorize an identical flow
// (same selector and sources) collapse to a single backing firewall
// rule, and that rule survives until BOTH policies are gone. This is
// why the backend can dedup on add without refcounting on delete: the
// acl manager's pair key matches the backend's content key, so add and
// delete stay balanced per content key across full-state reapplies.
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
acl, fw, cleanup := newCountingACL(t)
defer cleanup()
ruleA := &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
ruleB := &mgmProto.FirewallRule{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
// Both policies present: identical content collapses to one rule.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
// Drop policy A only: the shared rule is still authorized by B, so
// nothing is deleted.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Drop policy B too: now the content key has no authorizer and the
// single backing rule is removed exactly once.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
assert.Equal(t, 0, len(acl.peerRulesPairs))
}

View File

@@ -1,318 +0,0 @@
package acl
import (
"errors"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
// with identical selectors but different PolicyIDs do NOT get merged
// into one group, so each policy's sources merge under its own
// attribution id. (Identical-content groups may still dedup to one
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
}
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
// belonging to the same policy don't merge.
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "2001:db8::1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules of different families must produce separate groups")
var sawV4, sawV6 bool
for _, g := range groups {
require.Len(t, g.sources, 1)
if g.sources[0].Addr().Is4() {
sawV4 = true
}
if g.sources[0].Addr().Is6() {
sawV6 = true
}
}
assert.True(t, sawV4 && sawV6)
}
// TestGroupPeerRulesSplitsMixedFamilySingleRule verifies that a single
// FirewallRule carrying both v4 and v6 source prefixes is split into one
// group per family. Each backend keys a rule to a single family, so a
// group whose sources span families would mismatch the other family's
// sources. mgmt normally emits one rule per family; this guards against
// a mixed-family rule slipping through.
func TestGroupPeerRulesSplitsMixedFamilySingleRule(t *testing.T) {
srcs := [][]byte{
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::1")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::2")),
}
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
SourcePrefixes: srcs,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "mixed-family sources in one rule must split into two groups")
for _, g := range groups {
require.Len(t, g.sources, 2)
v6 := prefixIsV6(g.sources[0])
for _, s := range g.sources {
assert.Equal(t, v6, prefixIsV6(s), "every source in a group must share one family")
}
}
}
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
// every distinguishing field (policy, family, direction, action,
// proto, port) collapse into a single multi-source group.
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
mk := func(peerIP string) *mgmProto.FirewallRule {
return &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: peerIP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
}
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesPortSeparates verifies that PortInfo is part of the
// selector key: rules differing only in port must not merge, and a
// single port must not merge with a range. A regression dropping the
// port from the key would collapse rules for different ports into one.
func TestGroupPeerRulesPortSeparates(t *testing.T) {
mkPort := func(peerIP string, port uint32) *mgmProto.FirewallRule {
return &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: peerIP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Port{Port: port}},
}
}
groups, denyErr, err := groupPeerRules([]*mgmProto.FirewallRule{
mkPort("10.0.0.1", 80), mkPort("10.0.0.2", 80), mkPort("10.0.0.3", 443),
})
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules on different ports must not merge")
rangeRule := &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Range_{Range: &mgmProto.PortInfo_Range{Start: 80, End: 90}}},
}
groups, denyErr, err = groupPeerRules([]*mgmProto.FirewallRule{mkPort("10.0.0.1", 80), rangeRule})
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "a single port and a range must not merge")
}
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
// new sourcePrefixes wire field is consumed and produces a
// multi-source group in one shot (no client-side merging needed).
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
srcs := [][]byte{
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
}
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
SourcePrefixes: srcs,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
// and drop rules with the same selector don't merge.
func TestGroupPeerRulesActionSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2)
}
// failingDeleteFirewall wraps a real firewall.Manager and forces the
// next N DeleteFilterRule calls to fail. Used to verify that the acl
// manager retains rules whose deletion was rejected by the backend,
// so they get retried on the next ApplyFiltering pass instead of
// becoming orphans.
type failingDeleteFirewall struct {
fwmgr.Manager
failCount int
}
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
if f.failCount > 0 {
f.failCount--
return errors.New("simulated delete failure")
}
return f.Manager.DeleteFilterRule(r)
}
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
// transient DeleteFilterRule error doesn't make the acl manager
// forget about a rule. The rule must remain in peerRulesPairs so the
// next ApplyFiltering pass attempts the delete again.
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() { require.NoError(t, realFW.Close(nil)) }()
fw := &failingDeleteFirewall{Manager: realFW}
acl := NewDefaultManager(fw)
// First pass: install a rule.
netmap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(netmap1, false)
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
// Second pass: remove the rule from the map. The backend will
// fail the delete; the acl manager must retain the rule.
fw.failCount = 1
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 1, len(acl.peerRulesPairs),
"rule must be retained when DeleteFilterRule fails so it gets retried")
// Third pass: same map, backend no longer fails. The rule
// should now succeed in being removed.
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
}

View File

@@ -5,18 +5,18 @@ import (
"encoding/hex"
"fmt"
"net/netip"
"slices"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
// RuleID aliases manager.RuleID so existing nbid.RuleID references
// keep working while the canonical type lives in the firewall package.
type RuleID = manager.RuleID
type RuleID string
// GenerateRuleID returns a deterministic content hash identifying a filter rule.
func GenerateRuleID(
func (r RuleID) ID() string {
return string(r)
}
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination manager.Network,
proto manager.Protocol,
@@ -24,7 +24,6 @@ func GenerateRuleID(
dPort *manager.Port,
action manager.Action,
) RuleID {
sources = slices.Clone(sources)
manager.SortPrefixes(sources)
h := sha256.New()

View File

@@ -1,6 +1,8 @@
package acl
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"net/netip"
@@ -21,10 +23,6 @@ import (
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// ErrNoRuleReturned is returned when the firewall backend reports success
// from AddFilterRule but yields no rule to track.
var ErrNoRuleReturned = errors.New("backend returned no rule")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
@@ -33,46 +31,17 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
}
// peerRuleGroup collapses a set of single-source FirewallRules sharing
// the same selector into one multi-source rule to push to the backend.
type peerRuleGroup struct {
direction mgmProto.RuleDirection
action mgmProto.RuleAction
protocol mgmProto.RuleProtocol
port *mgmProto.PortInfo
// legacyPort is used only when PortInfo is empty (old management).
legacyPort string
policyID []byte
sources []netip.Prefix
}
// peerRuleKey is the comparable selector that decides which single-source
// rules merge into one group. Rules with an equal key collapse into one
// multi-source backend rule. PortInfo is flattened into its scalar fields
// so the key compares by value; policyID keeps policies separate so two
// policies authorizing different peers don't merge under one attribution.
type peerRuleKey struct {
v6 bool
policyID string
direction mgmProto.RuleDirection
action mgmProto.RuleAction
protocol mgmProto.RuleProtocol
legacyPort string
port uint16
rangeStart uint16
rangeEnd uint16
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
firewall: fm,
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
routeRules: make(map[id.RuleID]firewall.Rule),
routeRules: make(map[id.RuleID]struct{}),
}
}
@@ -99,12 +68,10 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total)
}()
if err := d.applyPeerACLs(networkMap); err != nil {
log.Errorf("apply peer ACLs: %v", err)
}
d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("apply route ACLs: %v", err)
log.Errorf("Failed to apply route ACLs: %v", err)
}
if err := d.firewall.Flush(); err != nil {
@@ -112,7 +79,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
}
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := networkMap.FirewallRules
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
@@ -135,158 +102,59 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
)
}
// Group incoming single-source rules from management by their
// (direction, action, proto, port) selector and merge sources.
// One call to the firewall backend per merged rule.
// A deny we cannot decode would leave its traffic unblocked, so skip
// the whole pass and keep existing rules until the next sync.
groups, denyErr, err := groupPeerRules(rules)
if denyErr != nil {
return fmt.Errorf("decode deny rule sources: %w", denyErr)
}
newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
// the missing deny. Currently we accumulate errors and continue.
var merr *multierror.Error
if err != nil {
merr = multierror.Append(merr, err)
}
// Apply denies first. A deny that fails to install is a security
// failure (fail-open), so if any deny errors we roll back the
// denies we already installed in this pass and bail out without
// installing any accept. Pre-existing rules stay untouched until
// the next successful pass clears them.
denies, accepts := splitDenyAccept(groups)
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
return fmt.Errorf("install deny rules: %w", err)
}
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
merr = multierror.Append(merr, err)
}
// Tear down rules that disappeared from the networkmap. Any rule
// the backend refuses to delete stays in our tracking so it gets
// retried on the next ApplyFiltering. Otherwise a transient
// delete failure would leak the rule in the firewall until the
// process exits.
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; ok {
continue
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
var remaining []firewall.Rule
for _, rule := range rules {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
remaining = append(remaining, rule)
}
}
if len(remaining) > 0 {
newRulePairs[pairID] = remaining
}
}
d.peerRulesPairs = newRulePairs
return nberrors.FormatErrorOrNil(merr)
}
// installPeerGroups applies each group and records the resulting rule
// pairs in newRulePairs. With atomic set (deny rules), a single failure
// rolls back every rule installed in this call and returns, leaving the
// firewall exactly as before: denies are fail-closed and must be applied
// all-or-nothing. With atomic unset (accept rules), failures are
// accumulated and the remaining groups still install, so one malformed
// allow cannot drop every other legitimate allow in the pass.
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
var freshlyInstalled []id.RuleID
var merr *multierror.Error
for _, g := range groups {
pairID, rulePair, err := d.applyPeerGroup(g)
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
if atomic {
d.rollbackInstalled(freshlyInstalled)
return fmt.Errorf("apply firewall rule: %w", err)
}
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
continue
}
if len(rulePair) == 0 {
continue
if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
if _, existed := d.peerRulesPairs[pairID]; !existed {
freshlyInstalled = append(freshlyInstalled, pairID)
}
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
var merr *multierror.Error
for _, pairID := range pairIDs {
for _, rule := range d.peerRulesPairs[pairID] {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rule %s: %w", pairID, err))
if merr != nil {
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
}
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule: %v", err)
continue
}
}
delete(d.peerRulesPairs, pairID)
}
delete(d.peerRulesPairs, pairID)
}
if err := nberrors.FormatErrorOrNil(merr); err != nil {
log.Errorf("rollback peer rules: %v", err)
}
}
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
protocol, err := ConvertToFirewallProtocol(g.protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
action, err := convertFirewallAction(g.action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
port, err := resolveGroupPort(g)
if err != nil {
return "", nil, err
}
var fwRule firewall.Rule
switch g.direction {
case mgmProto.RuleDirection_IN:
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
if shouldSkipInvertedRule(protocol, port) {
return "", nil, nil
}
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, fmt.Errorf("add firewall rule: %w", err)
}
if fwRule == nil {
return "", nil, nil
}
// Derive the pair id from the backend rule, like the route path:
// the backend dedups identical content, so two policies authorizing
// the same flow resolve to the same id and a single backing rule.
return fwRule.ID(), []firewall.Rule{fwRule}, nil
d.peerRulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]firewall.Rule, len(rules))
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return the existing rule if already present
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
addedRule, err := d.applyRouteACL(rule, dynamicResolver)
id, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
@@ -295,18 +163,16 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
}
continue
}
newRouteRules[addedRule.ID()] = addedRule
newRouteRules[id] = struct{}{}
}
// Tear down old route rules; retain ones the backend refused so a
// transient failure doesn't leave orphaned rules in the firewall.
for ruleID, rule := range d.routeRules {
if _, exists := newRouteRules[ruleID]; exists {
continue
}
if err := d.firewall.DeleteFilterRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
newRouteRules[ruleID] = rule
// Clean up old firewall rules
for id := range d.routeRules {
if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
}
// implicitly deleted from the map
}
}
@@ -314,196 +180,102 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (firewall.Rule, error) {
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return nil, ErrSourceRangesEmpty
return "", ErrSourceRangesEmpty
}
var sources []netip.Prefix
for _, sourceRange := range rule.SourceRanges {
source, err := netip.ParsePrefix(sourceRange)
if err != nil {
return nil, fmt.Errorf("parse source range: %w", err)
return "", fmt.Errorf("parse source range: %w", err)
}
sources = append(sources, firewall.UnmapPrefix(source))
sources = append(sources, source)
}
destination, err := determineDestination(rule, dynamicResolver, sources)
if err != nil {
return nil, fmt.Errorf("determine destination: %w", err)
return "", fmt.Errorf("determine destination: %w", err)
}
protocol, err := ConvertToFirewallProtocol(rule.Protocol)
protocol, err := convertToFirewallProtocol(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("invalid protocol: %w", err)
return "", fmt.Errorf("invalid protocol: %w", err)
}
action, err := convertFirewallAction(rule.Action)
if err != nil {
return nil, fmt.Errorf("invalid action: %w", err)
return "", fmt.Errorf("invalid action: %w", err)
}
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
if err != nil {
return nil, fmt.Errorf("add route rule: %w", err)
}
if addedRule == nil {
return nil, fmt.Errorf("add route rule: %w", ErrNoRuleReturned)
return "", fmt.Errorf("add route rule: %w", err)
}
return addedRule, nil
return id.RuleID(addedRule.ID()), nil
}
// splitDenyAccept partitions groups by action so denies can be
// applied before accepts. Order within each bucket is preserved.
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
for _, g := range groups {
if g.action == mgmProto.RuleAction_DROP {
denies = append(denies, g)
} else {
accepts = append(accepts, g)
}
}
return denies, accepts
}
// groupPeerRules merges single-source rules sharing a selector into
// multi-source groups. It splits source-decode failures by action:
// denyErr is non-nil when a deny rule could not be decoded, which is a
// fail-open risk the caller must treat as fatal for the pass; err
// carries the tolerable accept-rule failures the caller can log and
// continue past.
func groupPeerRules(rules []*mgmProto.FirewallRule) (groups []*peerRuleGroup, denyErr error, err error) {
var denyMerr, acceptMerr *multierror.Error
byKey := make(map[peerRuleKey]*peerRuleGroup)
order := make([]peerRuleKey, 0)
for _, r := range rules {
srcs, decErr := extractRuleSources(r)
if decErr != nil {
if r.Action == mgmProto.RuleAction_DROP {
denyMerr = multierror.Append(denyMerr, decErr)
} else {
acceptMerr = multierror.Append(acceptMerr, decErr)
}
continue
}
// A single FirewallRule normally carries one address family, but
// split by family defensively: each backend keys a rule to one
// family and would mismatch sources of the other, so a group's
// sources must never span families.
v4, v6 := splitPrefixesByFamily(srcs)
for _, sub := range []struct {
isV6 bool
sources []netip.Prefix
}{{false, v4}, {true, v6}} {
if len(sub.sources) == 0 {
continue
}
key := ruleGroupKey(r, sub.isV6)
g, ok := byKey[key]
if !ok {
g = &peerRuleGroup{
direction: r.Direction,
action: r.Action,
protocol: r.Protocol,
port: r.PortInfo,
legacyPort: r.Port,
policyID: r.PolicyID,
}
byKey[key] = g
order = append(order, key)
}
g.sources = append(g.sources, sub.sources...)
}
}
out := make([]*peerRuleGroup, 0, len(order))
for _, k := range order {
out = append(out, byKey[k])
}
return out, nberrors.FormatErrorOrNil(denyMerr), nberrors.FormatErrorOrNil(acceptMerr)
}
func prefixIsV6(p netip.Prefix) bool {
return p.Addr().Is6() && !p.Addr().Is4In6()
}
// splitPrefixesByFamily partitions prefixes into IPv4 and IPv6 groups.
func splitPrefixesByFamily(prefixes []netip.Prefix) (v4, v6 []netip.Prefix) {
for _, p := range prefixes {
if prefixIsV6(p) {
v6 = append(v6, p)
} else {
v4 = append(v4, p)
}
}
return v4, v6
}
// ruleGroupKey builds the selector key for a rule. v6 must reflect the
// rule's source family: mgmt emits one rule per family and mixing them
// would break ICMP-variant selection in uspfilter.
func ruleGroupKey(r *mgmProto.FirewallRule, v6 bool) peerRuleKey {
k := peerRuleKey{
v6: v6,
policyID: string(r.PolicyID),
direction: r.Direction,
action: r.Action,
protocol: r.Protocol,
legacyPort: r.Port,
}
if pi := r.PortInfo; pi != nil {
k.port = uint16(pi.GetPort())
if rng := pi.GetRange(); rng != nil {
k.rangeStart = uint16(rng.GetStart())
k.rangeEnd = uint16(rng.GetEnd())
}
}
return k
}
// extractRuleSources returns all source prefixes the rule applies to.
// New management populates sourcePrefixes; older management sets PeerIP.
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
if len(r.SourcePrefixes) > 0 {
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
for _, raw := range r.SourcePrefixes {
addr, err := netiputil.DecodeAddr(raw)
if err != nil {
return nil, fmt.Errorf("decode source prefix: %w", err)
}
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
}
return out, nil
}
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (id.RuleID, []firewall.Rule, error) {
ip, err := extractRuleIP(r)
if err != nil {
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
return "", nil, err
}
addr = addr.Unmap()
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
}
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
if !portInfoEmpty(g.port) {
return convertPortInfo(g.port), nil
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
if g.legacyPort != "" {
value, err := strconv.Atoi(g.legacyPort)
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
if !portInfoEmpty(r.PortInfo) {
port = convertPortInfo(r.PortInfo)
} else if r.Port != "" {
// old version of management, single port
value, err := strconv.Atoi(r.Port)
if err != nil {
return nil, fmt.Errorf("invalid port: %w", err)
return "", nil, fmt.Errorf("invalid port: %w", err)
}
return &firewall.Port{
port = &firewall.Port{
Values: []uint16{uint16(value)},
}, nil
}
}
// nolint:nilnil // a nil port legitimately means "no port restriction"
return nil, nil
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, err
}
return ruleID, rules, nil
}
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
@@ -522,9 +294,85 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
}
}
// ConvertToFirewallProtocol maps a management rule protocol to the
// firewall protocol type.
func ConvertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
func (d *DefaultManager) addInRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
func (d *DefaultManager) addOutRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
if shouldSkipInvertedRule(protocol, port) {
return nil, nil
}
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
// getPeerRuleID returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID(
ip netip.Addr,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
if port != nil {
idStr += port.String()
}
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
}
// extractRuleIP extracts the peer IP from a firewall rule.
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
// Otherwise fall back to the deprecated PeerIP string field (old management).
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
if len(r.SourcePrefixes) > 0 {
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
if err != nil {
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
}
return addr.Unmap(), nil
}
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
}
return addr.Unmap(), nil
}
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewall.ProtocolTCP, nil

View File

@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
@@ -77,9 +76,9 @@ func TestDefaultManager(t *testing.T) {
})
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[fwmanager.RuleID]struct{}{}
existedPairs := map[string]struct{}{}
for id := range acl.peerRulesPairs {
existedPairs[id] = struct{}{}
existedPairs[id.ID()] = struct{}{}
}
// remove first rule
@@ -106,7 +105,7 @@ func TestDefaultManager(t *testing.T) {
// check that old rule was removed
previousCount := 0
for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id]; ok {
if _, ok := existedPairs[id.ID()]; ok {
previousCount++
}
}

View File

@@ -360,13 +360,7 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
return true
}
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
// alias by default, ensure explicit ip for localhost.
host := parsedURL.Hostname()
if host == "" {
host = "127.0.0.1"
}
addr := net.JoinHostPort(host, port)
addr := fmt.Sprintf(":%s", port)
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false

View File

@@ -339,7 +339,8 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
return strings.HasSuffix(qname, "."+entry.Pattern)
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match

View File

@@ -164,54 +164,6 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: true,
shouldMatch: true,
},
{
name: "wildcard label-boundary mismatch (suffix overlap)",
handlerDomain: "*.b.test.",
queryDomain: "x.ab.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard label-boundary match",
handlerDomain: "*.b.test.",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard multi-label match",
handlerDomain: "*.b.test.",
queryDomain: "x.y.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on multi-label apex",
handlerDomain: "*.b.test.",
queryDomain: "b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard no match on unrelated suffix containment",
handlerDomain: "*.example.com.",
queryDomain: "notexample.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard accepts pattern registered without trailing dot",
handlerDomain: "*.b.test",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
}
for _, tt := range tests {
@@ -321,19 +273,6 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "overlapping wildcard suffixes route to correct handler",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "app.ab.test.",
expectedCalls: 1,
expectedHandler: 1,
},
{
name: "root zone with specific domain",
handlers: []struct {

View File

@@ -26,19 +26,6 @@ type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
// client knows about and whether that peer is currently connected. The
// local resolver uses this to suppress A/AAAA answers whose RDATA points
// at a disconnected peer (typical case: a synthesized private-service
// record pointing at an embedded proxy peer that just went offline).
//
// known=false means the IP isn't in the local peerstore at all — the
// record is left alone (it points at something outside our mesh, e.g.
// a non-peer upstream).
type PeerConnectivity interface {
IsConnectedByIP(ip string) (known, connected bool)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
@@ -46,11 +33,6 @@ type Resolver struct {
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
// peerConn, when non-nil, is consulted on every A/AAAA answer to
// drop records pointing at disconnected peers. nil disables the
// filter and preserves the legacy "return whatever is registered"
// behaviour for callers that never wire a status source.
peerConn PeerConnectivity
ctx context.Context
cancel context.CancelFunc
@@ -67,15 +49,6 @@ func NewResolver() *Resolver {
}
}
// SetPeerConnectivity wires the per-IP connectivity check used to filter
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
// Safe to call multiple times; the latest value wins.
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerConn = p
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
@@ -122,7 +95,6 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question)
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
@@ -464,78 +436,6 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
}
}
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
// a known but disconnected peer. The synthesized private-service zones
// emit one A record per connected proxy peer in a cluster; when a peer
// goes offline, the server-side refresh removes the record from the
// next netmap, but the client may still hold the previous netmap for a
// short window. This filter is the local belt to that braces — even on
// the stale netmap, the resolver hides the offline target.
//
// Records pointing at unknown IPs (outside the local peerstore, e.g.
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
// through untouched.
//
// Escape hatch: if filtering would leave the answer empty AND at least
// one record was filtered, the original list is returned. Better to
// hand the client a record that may not respond than NXDOMAIN it
// completely when every proxy peer is offline (the upstream may still
// be reachable some other way, or the peerstore may be stale).
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
if len(records) == 0 {
return records
}
d.mu.RLock()
checker := d.peerConn
d.mu.RUnlock()
if checker == nil {
return records
}
kept := make([]dns.RR, 0, len(records))
var dropped int
for _, rr := range records {
ip := extractRecordIP(rr)
if ip == "" {
kept = append(kept, rr)
continue
}
known, connected := checker.IsConnectedByIP(ip)
if known && !connected {
dropped++
continue
}
kept = append(kept, rr)
}
if dropped == 0 {
return records
}
if len(kept) == 0 {
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
return records
}
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
return kept
}
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
// an A or AAAA record, or "" for any other record type.
func extractRecordIP(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
if r.A == nil {
return ""
}
return r.A.String()
case *dns.AAAA:
if r.AAAA == nil {
return ""
}
return r.AAAA.String()
}
return ""
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()

View File

@@ -30,21 +30,6 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return nil, nil
}
// mockPeerConnectivity returns canned (known, connected) results per IP.
// Used by the disconnected-peer filter tests below. IPs not in the map
// are reported as unknown so the filter leaves them alone.
type mockPeerConnectivity struct {
byIP map[string]struct{ known, connected bool }
}
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
v, ok := m.byIP[ip]
if !ok {
return false, false
}
return v.known, v.connected
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -2667,114 +2652,3 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver.isInManagedZone(qname)
}
}
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
// connectivity-aware filtering layered on top of lookupRecords:
// when an A record's IP belongs to a known peer that's disconnected,
// the record is dropped from the answer. Records for unknown IPs pass
// through. If filtering would empty the answer entirely and at least
// one record was dropped, the original list is restored (escape hatch
// for the "all proxies offline" case).
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
zone := "svc.cluster.netbird."
connectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.10",
}
disconnectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.11",
}
unknownRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "203.0.113.5",
}
type ipState struct{ known, connected bool }
tests := []struct {
name string
records []nbdns.SimpleRecord
connByIP map[string]ipState
wantInOrder []string
}{
{
name: "drops disconnected peer, keeps connected",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: true},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.10"},
},
{
name: "unknown IPs pass through untouched",
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"203.0.113.5"},
},
{
name: "all disconnected falls back to original list",
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: false},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
},
{
name: "no checker wired returns all records",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: nil,
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
if tc.connByIP != nil {
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
for ip, st := range tc.connByIP {
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
}
resolver.SetPeerConnectivity(cm)
}
resolver.Update([]nbdns.CustomZone{{
Domain: strings.TrimSuffix(zone, "."),
Records: tc.records,
NonAuthoritative: true,
}})
var got *dns.Msg
writer := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
got = m
return nil
},
}
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
resolver.ServeDNS(writer, req)
require.NotNil(t, got, "resolver must produce a response")
require.Len(t, got.Answer, len(tc.wantInOrder),
"answer count must match expected: %v", tc.wantInOrder)
for i, want := range tc.wantInOrder {
a, ok := got.Answer[i].(*dns.A)
require.True(t, ok, "answer[%d] must be an A record", i)
assert.Equal(t, want, a.A.String(),
"answer[%d] expected %s got %s", i, want, a.A.String())
}
})
}
}

View File

@@ -301,11 +301,6 @@ func newDefaultServer(
warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
// suppress A/AAAA answers that point at disconnected peers (typical
// case: synthesised private-service records pointing at an embedded
// proxy peer that just went offline).
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
@@ -1391,25 +1386,3 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
}
return nil
}
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
// the local resolver can ask "is this IP a known peer and is it
// connected?" without taking on the peer package as a dependency.
// A nil status recorder always reports known=false so the resolver
// short-circuits to the legacy "return everything" path.
type localPeerConnectivity struct {
status *peer.Status
}
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
if l.status == nil {
return false, false
}
state, ok := l.status.PeerStateByIP(ip)
if !ok {
return false, false
}
return true, state.ConnStatus == peer.StatusConnected
}

View File

@@ -900,7 +900,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(wgIface, nil, false, flowLogger, iface.DefaultMTU)
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err

View File

@@ -3,6 +3,7 @@ package dnsfwd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"strconv"
@@ -159,13 +160,12 @@ func (m *Manager) allowDNSFirewall() error {
return nil
}
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
dnsRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil {
return fmt.Errorf("add udp firewall rule: %w", err)
}
tcpRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
if err != nil {
return fmt.Errorf("add tcp firewall rule: %w", err)
}
@@ -174,12 +174,8 @@ func (m *Manager) allowDNSFirewall() error {
return fmt.Errorf("flush: %w", err)
}
if dnsRule != nil {
m.fwRules = []firewall.Rule{dnsRule}
}
if tcpRule != nil {
m.tcpRules = []firewall.Rule{tcpRule}
}
m.fwRules = dnsRules
m.tcpRules = tcpRules
m.registerNetstackServices()
@@ -213,12 +209,12 @@ func (m *Manager) unregisterNetstackServices() {
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range m.fwRules {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
for _, rule := range m.tcpRules {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}

View File

@@ -640,14 +640,14 @@ func (e *Engine) initFirewall() error {
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
if _, err := e.firewall.AddFilterRule(
if _, err := e.firewall.AddPeerFiltering(
nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewallManager.Network{},
net.IP{0, 0, 0, 0},
firewallManager.ProtocolUDP,
nil,
&port,
firewallManager.ActionAccept,
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
@@ -697,7 +697,7 @@ func (e *Engine) blockLanAccess() {
if network.Addr().Is6() {
source = v6
}
if _, err := e.firewall.AddFilterRule(
if _, err := e.firewall.AddRouteFiltering(
nil,
[]netip.Prefix{source},
firewallManager.Network{Prefix: network},
@@ -1967,29 +1967,6 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
// Performance bundles runtime-adjustable tunnel pool knobs.
// See Engine.SetPerformance. Nil fields are ignored.
type Performance struct {
PreallocatedBuffersPerPool *uint32
}
// SetPerformance applies the given tuning to this engine's live Device.
func (e *Engine) SetPerformance(t Performance) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
@@ -2369,7 +2346,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
var merr *multierror.Error
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
for _, rule := range rules {
proto, err := acl.ConvertToFirewallProtocol(rule.GetProtocol())
proto, err := convertToFirewallProtocol(rule.GetProtocol())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
continue

View File

@@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -66,8 +66,8 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/shared/netiputil"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)

View File

@@ -24,14 +24,14 @@ type RulePair struct {
type Manager struct {
dnatFirewall DNATFirewall
rules map[firewall.RuleID]RulePair
rules map[string]RulePair // keys is the ID of the ForwardRule
rulesMu sync.Mutex
}
func NewManager(dnatFirewall DNATFirewall) *Manager {
return &Manager{
dnatFirewall: dnatFirewall,
rules: make(map[firewall.RuleID]RulePair),
rules: make(map[string]RulePair),
}
}
@@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
var mErr *multierror.Error
toDelete := make(map[firewall.RuleID]RulePair, len(h.rules))
toDelete := make(map[string]RulePair, len(h.rules))
for id, r := range h.rules {
toDelete[id] = r
}
@@ -59,10 +59,6 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
continue
}
if rule == nil {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': backend returned no rule", fwdRule.String()))
continue
}
log.Infof("forward rule has been added '%s'", fwdRule)
h.rules[id] = RulePair{
ForwardRule: fwdRule,
@@ -94,7 +90,7 @@ func (h *Manager) Close() error {
}
}
h.rules = make(map[firewall.RuleID]RulePair)
h.rules = make(map[string]RulePair)
return nberrors.FormatErrorOrNil(mErr)
}

View File

@@ -14,11 +14,11 @@ var (
)
type MocFwRule struct {
id firewall.RuleID
id string
}
func (m *MocFwRule) ID() firewall.RuleID {
return m.id
func (m *MocFwRule) ID() string {
return string(m.id)
}
type MockDNATFirewall struct {

View File

@@ -10,6 +10,21 @@ import (
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewallManager.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP:
return firewallManager.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP:
return firewallManager.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL:
return firewallManager.ProtocolALL, nil
default:
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
if portInfo == nil {
return nil, errors.New("portInfo cannot be nil")

View File

@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, flags, err := parseRouteMessage(buf[:n])
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
@@ -66,10 +66,6 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
switch msg.Type {
case unix.RTM_ADD:
if systemops.IgnoreAddedDefaultRoute(flags) {
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
continue
}
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
@@ -82,26 +78,22 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, 0, fmt.Errorf("parse RIB: %v", err)
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
r, err := systemops.MsgToRoute(msg)
if err != nil {
return nil, 0, err
}
return r, msg.Flags, nil
return systemops.MsgToRoute(msg)
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -900,7 +899,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil
@@ -909,6 +908,26 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey
}
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -185,12 +185,9 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
return s.eventsChan
}
// Status holds a state of peers, signal, management connections and relays.
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
// every private-service request) don't contend against each other.
// Pure read methods take RLock; anything that mutates state takes Lock.
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.RWMutex
mux sync.Mutex
peers map[string]State
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool
@@ -286,8 +283,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
// GetPeer adds peer to Daemon status map
func (d *Status) GetPeer(peerPubKey string) (State, error) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
state, ok := d.peers[peerPubKey]
if !ok {
@@ -297,8 +294,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
}
func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
for _, state := range d.peers {
if state.IP == ip {
@@ -308,25 +305,6 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
return "", false
}
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
}
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}
// RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
@@ -724,8 +702,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return d.localPeer.Clone()
}
@@ -931,8 +909,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -940,14 +918,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetLazyConnection() bool {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return d.lazyConnectionEnabled
}
func (d *Status) GetManagementState() ManagementState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -973,8 +951,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
// if peer is connected to the management then login is not expired
if d.managementState {
@@ -989,8 +967,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -1000,8 +978,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -1030,8 +1008,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.ingressGwMgr == nil {
return nil
}
@@ -1040,16 +1018,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
// shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates)
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
@@ -1065,8 +1043,8 @@ func (d *Status) GetFullStatus() FullStatus {
LazyConnectionEnabled: d.GetLazyConnection(),
}
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
@@ -1241,8 +1219,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
}
func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
}

View File

@@ -63,33 +63,6 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal")
}
func TestStatus_PeerStateByIP(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
state, ok := status.PeerStateByIP("100.64.0.10")
req.True(ok, "known tunnel IP should resolve to a peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
_, ok = status.PeerStateByIP("100.64.0.99")
req.False(ok, "unknown IP must report ok=false")
}
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
state, ok := status.PeerStateByIP("fd00::1")
req.True(ok, "IPv6-only match must resolve to the peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"

View File

@@ -179,10 +179,8 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv4zero
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux/Android.
// TODO: on android/ios, use platform APIs (ConnectivityManager.getLinkProperties /
// NWPathMonitor) when netlink-based lookup is restricted or unavailable.
if runtime.GOOS == "linux" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
dst = net.IPv4(0, 0, 0, 1)
}
_, gateway, localIP, err = router.Route(dst)
@@ -205,7 +203,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv6zero
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if runtime.GOOS == "linux" {
// ::2
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
}

Some files were not shown because too many files have changed in this diff Show More